From 3bd6b3ce288acad7e915b37d70f1c4e16e506064 Mon Sep 17 00:00:00 2001 From: Zach Reyes Date: Tue, 29 Aug 2023 17:37:05 -0400 Subject: [PATCH 1/2] Change server context reading --- internal/transport/handler_server.go | 2 +- internal/transport/handler_server_test.go | 5 - internal/transport/http2_server.go | 7 +- internal/transport/transport.go | 2 +- internal/transport/transport_test.go | 22 +--- server.go | 136 ++++++++++------------ stats/stats_test.go | 10 -- 7 files changed, 67 insertions(+), 117 deletions(-) diff --git a/internal/transport/handler_server.go b/internal/transport/handler_server.go index 98f80e3fa00a..5e14f934d614 100644 --- a/internal/transport/handler_server.go +++ b/internal/transport/handler_server.go @@ -344,7 +344,7 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error { return err } -func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), traceCtx func(context.Context, string) context.Context) { +func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream)) { // With this transport type there will be exactly 1 stream: this HTTP request. ctx := ht.req.Context() diff --git a/internal/transport/handler_server_test.go b/internal/transport/handler_server_test.go index 36b0864177e7..bf67480b318f 100644 --- a/internal/transport/handler_server_test.go +++ b/internal/transport/handler_server_test.go @@ -315,7 +315,6 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) { } st.ht.HandleStreams( func(s *Stream) { go handleStream(s) }, - func(ctx context.Context, method string) context.Context { return ctx }, ) wantHeader := http.Header{ "Date": nil, @@ -349,7 +348,6 @@ func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) } st.ht.HandleStreams( func(s *Stream) { go handleStream(s) }, - func(ctx context.Context, method string) context.Context { return ctx }, ) wantHeader := http.Header{ "Date": nil, @@ -399,7 +397,6 @@ func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) { } ht.HandleStreams( func(s *Stream) { go runStream(s) }, - func(ctx context.Context, method string) context.Context { return ctx }, ) wantHeader := http.Header{ "Date": nil, @@ -452,7 +449,6 @@ func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handl st := newHandleStreamTest(t) st.ht.HandleStreams( func(s *Stream) { go handleStream(st, s) }, - func(ctx context.Context, method string) context.Context { return ctx }, ) } @@ -486,7 +482,6 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) { } hst.ht.HandleStreams( func(s *Stream) { go handleStream(s) }, - func(ctx context.Context, method string) context.Context { return ctx }, ) wantHeader := http.Header{ "Date": nil, diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 8d3a353c1d58..cadc64043f47 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -347,7 +347,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, // operateHeaders takes action on the decoded headers. Returns an error if fatal // error encountered and transport needs to close, otherwise returns nil. -func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream), traceCtx func(context.Context, string) context.Context) error { +func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) error { // Acquire max stream ID lock for entire duration t.maxStreamMu.Lock() defer t.maxStreamMu.Unlock() @@ -597,7 +597,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( s.requestRead = func(n int) { t.adjustWindow(s, uint32(n)) } - s.ctx = traceCtx(s.ctx, s.method) for _, sh := range t.stats { s.ctx = sh.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method}) inHeader := &stats.InHeader{ @@ -635,7 +634,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( // HandleStreams receives incoming streams using the given handler. This is // typically run in a separate goroutine. // traceCtx attaches trace to ctx and returns the new context. -func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.Context, string) context.Context) { +func (t *http2Server) HandleStreams(handle func(*Stream)) { defer close(t.readerDone) for { t.controlBuf.throttle() @@ -670,7 +669,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context. } switch frame := frame.(type) { case *http2.MetaHeadersFrame: - if err := t.operateHeaders(frame, handle, traceCtx); err != nil { + if err := t.operateHeaders(frame, handle); err != nil { t.Close(err) break } diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 74a811fc0590..aac056e723bb 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -698,7 +698,7 @@ type ClientTransport interface { // Write methods for a given Stream will be called serially. type ServerTransport interface { // HandleStreams receives incoming streams using the given handler. - HandleStreams(func(*Stream), func(context.Context, string) context.Context) + HandleStreams(func(*Stream)) // WriteHeader sends the header metadata for the given stream. // WriteHeader may not be called on all streams. diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index 258ef7411cf0..795687861f57 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -353,32 +353,20 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT s.mu.Unlock() switch ht { case notifyCall: - go transport.HandleStreams(h.handleStreamAndNotify, - func(ctx context.Context, _ string) context.Context { - return ctx - }) + go transport.HandleStreams(h.handleStreamAndNotify) case suspended: - go transport.HandleStreams(func(*Stream) {}, // Do nothing to handle the stream. - func(ctx context.Context, method string) context.Context { - return ctx - }) + go transport.HandleStreams(func(*Stream) {}) case misbehaved: go transport.HandleStreams(func(s *Stream) { go h.handleStreamMisbehave(t, s) - }, func(ctx context.Context, method string) context.Context { - return ctx }) case encodingRequiredStatus: go transport.HandleStreams(func(s *Stream) { go h.handleStreamEncodingRequiredStatus(s) - }, func(ctx context.Context, method string) context.Context { - return ctx }) case invalidHeaderField: go transport.HandleStreams(func(s *Stream) { go h.handleStreamInvalidHeaderField(s) - }, func(ctx context.Context, method string) context.Context { - return ctx }) case delayRead: h.notify = make(chan struct{}) @@ -388,20 +376,14 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT s.mu.Unlock() go transport.HandleStreams(func(s *Stream) { go h.handleStreamDelayRead(t, s) - }, func(ctx context.Context, method string) context.Context { - return ctx }) case pingpong: go transport.HandleStreams(func(s *Stream) { go h.handleStreamPingPong(t, s) - }, func(ctx context.Context, method string) context.Context { - return ctx }) default: go transport.HandleStreams(func(s *Stream) { go h.handleStream(t, s) - }, func(ctx context.Context, method string) context.Context { - return ctx }) } } diff --git a/server.go b/server.go index 244123c6c5a8..60cb5a27c50a 100644 --- a/server.go +++ b/server.go @@ -616,7 +616,7 @@ func (s *Server) serverWorker() { func (s *Server) handleSingleStream(data *serverWorkerData) { defer data.wg.Done() - s.handleStream(data.st, data.stream, s.traceInfo(data.st, data.stream)) + s.handleStream(data.st, data.stream) } // initServerWorkers creates worker goroutines and a channel to process incoming @@ -995,14 +995,8 @@ func (s *Server) serveStreams(st transport.ServerTransport) { } go func() { defer wg.Done() - s.handleStream(st, stream, s.traceInfo(st, stream)) + s.handleStream(st, stream) }() - }, func(ctx context.Context, method string) context.Context { - if !EnableTracing { - return ctx - } - tr := trace.New("grpc.Recv."+methodFamily(method), method) - return trace.NewContext(ctx, tr) }) wg.Wait() } @@ -1051,30 +1045,6 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.serveStreams(st) } -// traceInfo returns a traceInfo and associates it with stream, if tracing is enabled. -// If tracing is not enabled, it returns nil. -func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Stream) (trInfo *traceInfo) { - if !EnableTracing { - return nil - } - tr, ok := trace.FromContext(stream.Context()) - if !ok { - return nil - } - - trInfo = &traceInfo{ - tr: tr, - firstLine: firstLine{ - client: false, - remoteAddr: st.RemoteAddr(), - }, - } - if dl, ok := stream.Context().Deadline(); ok { - trInfo.firstLine.deadline = time.Until(dl) - } - return trInfo -} - func (s *Server) addConn(addr string, st transport.ServerTransport) bool { s.mu.Lock() defer s.mu.Unlock() @@ -1135,7 +1105,7 @@ func (s *Server) incrCallsFailed() { atomic.AddInt64(&s.czData.callsFailed, 1) } -func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg any, cp Compressor, opts *transport.Options, comp encoding.Compressor) error { +func (s *Server) sendResponse(ctx context.Context, t transport.ServerTransport, stream *transport.Stream, msg any, cp Compressor, opts *transport.Options, comp encoding.Compressor) error { data, err := encode(s.getCodec(stream.ContentSubtype()), msg) if err != nil { channelz.Error(logger, s.channelzID, "grpc: server failed to encode response: ", err) @@ -1154,7 +1124,7 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str err = t.Write(stream, hdr, payload, opts) if err == nil { for _, sh := range s.opts.statsHandlers { - sh.HandleRPC(stream.Context(), outPayload(false, msg, data, payload, time.Now())) + sh.HandleRPC(ctx, outPayload(false, msg, data, payload, time.Now())) } } return err @@ -1196,7 +1166,7 @@ func getChainUnaryHandler(interceptors []UnaryServerInterceptor, curr int, info } } -func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, info *serviceInfo, md *MethodDesc, trInfo *traceInfo) (err error) { +func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTransport, stream *transport.Stream, info *serviceInfo, md *MethodDesc, trInfo *traceInfo) (err error) { shs := s.opts.statsHandlers if len(shs) != 0 || trInfo != nil || channelz.IsOn() { if channelz.IsOn() { @@ -1210,7 +1180,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. IsClientStream: false, IsServerStream: false, } - sh.HandleRPC(stream.Context(), statsBegin) + sh.HandleRPC(ctx, statsBegin) } if trInfo != nil { trInfo.tr.LazyLog(&trInfo.firstLine, false) @@ -1242,7 +1212,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. if err != nil && err != io.EOF { end.Error = toRPCErr(err) } - sh.HandleRPC(stream.Context(), end) + sh.HandleRPC(ctx, end) } if channelz.IsOn() { @@ -1264,7 +1234,6 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } } if len(binlogs) != 0 { - ctx := stream.Context() md, _ := metadata.FromIncomingContext(ctx) logEntry := &binarylog.ClientHeader{ Header: md, @@ -1350,7 +1319,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err) } for _, sh := range shs { - sh.HandleRPC(stream.Context(), &stats.InPayload{ + sh.HandleRPC(ctx, &stats.InPayload{ RecvTime: time.Now(), Payload: v, Length: len(d), @@ -1364,7 +1333,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. Message: d, } for _, binlog := range binlogs { - binlog.Log(stream.Context(), cm) + binlog.Log(ctx, cm) } } if trInfo != nil { @@ -1372,7 +1341,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } return nil } - ctx := NewContextWithServerTransportStream(stream.Context(), stream) + ctx = NewContextWithServerTransportStream(ctx, stream) reply, appErr := md.Handler(info.serviceImpl, ctx, df, s.opts.unaryInt) if appErr != nil { appStatus, ok := status.FromError(appErr) @@ -1397,7 +1366,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. Header: h, } for _, binlog := range binlogs { - binlog.Log(stream.Context(), sh) + binlog.Log(ctx, sh) } } st := &binarylog.ServerTrailer{ @@ -1405,7 +1374,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. Err: appErr, } for _, binlog := range binlogs { - binlog.Log(stream.Context(), st) + binlog.Log(ctx, st) } } return appErr @@ -1420,7 +1389,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. if stream.SendCompress() != sendCompressorName { comp = encoding.GetCompressor(stream.SendCompress()) } - if err := s.sendResponse(t, stream, reply, cp, opts, comp); err != nil { + if err := s.sendResponse(ctx, t, stream, reply, cp, opts, comp); err != nil { if err == io.EOF { // The entire stream is done (for unary RPC only). return err @@ -1447,8 +1416,8 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. Err: appErr, } for _, binlog := range binlogs { - binlog.Log(stream.Context(), sh) - binlog.Log(stream.Context(), st) + binlog.Log(ctx, sh) + binlog.Log(ctx, st) } } return err @@ -1462,8 +1431,8 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. Message: reply, } for _, binlog := range binlogs { - binlog.Log(stream.Context(), sh) - binlog.Log(stream.Context(), sm) + binlog.Log(ctx, sh) + binlog.Log(ctx, sm) } } if channelz.IsOn() { @@ -1481,7 +1450,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. Err: appErr, } for _, binlog := range binlogs { - binlog.Log(stream.Context(), st) + binlog.Log(ctx, st) } } return t.WriteStatus(stream, statusOK) @@ -1523,7 +1492,7 @@ func getChainStreamHandler(interceptors []StreamServerInterceptor, curr int, inf } } -func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, info *serviceInfo, sd *StreamDesc, trInfo *traceInfo) (err error) { +func (s *Server) processStreamingRPC(ctx context.Context, t transport.ServerTransport, stream *transport.Stream, info *serviceInfo, sd *StreamDesc, trInfo *traceInfo) (err error) { if channelz.IsOn() { s.incrCallsStarted() } @@ -1537,10 +1506,10 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp IsServerStream: sd.ServerStreams, } for _, sh := range shs { - sh.HandleRPC(stream.Context(), statsBegin) + sh.HandleRPC(ctx, statsBegin) } } - ctx := NewContextWithServerTransportStream(stream.Context(), stream) + ctx = NewContextWithServerTransportStream(ctx, stream) ss := &serverStream{ ctx: ctx, t: t, @@ -1576,7 +1545,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp end.Error = toRPCErr(err) } for _, sh := range shs { - sh.HandleRPC(stream.Context(), end) + sh.HandleRPC(ctx, end) } } @@ -1618,7 +1587,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp logEntry.PeerAddr = peer.Addr } for _, binlog := range ss.binlogs { - binlog.Log(stream.Context(), logEntry) + binlog.Log(ctx, logEntry) } } @@ -1696,7 +1665,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp Err: appErr, } for _, binlog := range ss.binlogs { - binlog.Log(stream.Context(), st) + binlog.Log(ctx, st) } } t.WriteStatus(ss.s, appStatus) @@ -1714,33 +1683,48 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp Err: appErr, } for _, binlog := range ss.binlogs { - binlog.Log(stream.Context(), st) + binlog.Log(ctx, st) } } return t.WriteStatus(ss.s, statusOK) } -func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream, trInfo *traceInfo) { +func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream) { + ctx := stream.Context() + var ti *traceInfo + if EnableTracing { + ti = &traceInfo{ + tr: trace.New("grpc.Recv."+methodFamily(stream.Method()), stream.Method()), + firstLine: firstLine{ + client: false, + remoteAddr: t.RemoteAddr(), + }, + } + if dl, ok := ctx.Deadline(); ok { + ti.firstLine.deadline = time.Until(dl) + } + } + sm := stream.Method() if sm != "" && sm[0] == '/' { sm = sm[1:] } pos := strings.LastIndex(sm, "/") if pos == -1 { - if trInfo != nil { - trInfo.tr.LazyLog(&fmtStringer{"Malformed method name %q", []any{sm}}, true) - trInfo.tr.SetError() + if ti != nil { + ti.tr.LazyLog(&fmtStringer{"Malformed method name %q", []any{sm}}, true) + ti.tr.SetError() } errDesc := fmt.Sprintf("malformed method name: %q", stream.Method()) if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil { - if trInfo != nil { - trInfo.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true) - trInfo.tr.SetError() + if ti != nil { + ti.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true) + ti.tr.SetError() } channelz.Warningf(logger, s.channelzID, "grpc: Server.handleStream failed to write status: %v", err) } - if trInfo != nil { - trInfo.tr.Finish() + if ti != nil { + ti.tr.Finish() } return } @@ -1750,17 +1734,17 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str srv, knownService := s.services[service] if knownService { if md, ok := srv.methods[method]; ok { - s.processUnaryRPC(t, stream, srv, md, trInfo) + s.processUnaryRPC(ctx, t, stream, srv, md, ti) return } if sd, ok := srv.streams[method]; ok { - s.processStreamingRPC(t, stream, srv, sd, trInfo) + s.processStreamingRPC(ctx, t, stream, srv, sd, ti) return } } // Unknown service, or known server unknown method. if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil { - s.processStreamingRPC(t, stream, nil, unknownDesc, trInfo) + s.processStreamingRPC(ctx, t, stream, nil, unknownDesc, ti) return } var errDesc string @@ -1769,19 +1753,19 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str } else { errDesc = fmt.Sprintf("unknown method %v for service %v", method, service) } - if trInfo != nil { - trInfo.tr.LazyPrintf("%s", errDesc) - trInfo.tr.SetError() + if ti != nil { + ti.tr.LazyPrintf("%s", errDesc) + ti.tr.SetError() } if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil { - if trInfo != nil { - trInfo.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true) - trInfo.tr.SetError() + if ti != nil { + ti.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true) + ti.tr.SetError() } channelz.Warningf(logger, s.channelzID, "grpc: Server.handleStream failed to write status: %v", err) } - if trInfo != nil { - trInfo.tr.Finish() + if ti != nil { + ti.tr.Finish() } } diff --git a/stats/stats_test.go b/stats/stats_test.go index 7dcb53491a26..903c3ed7a774 100644 --- a/stats/stats_test.go +++ b/stats/stats_test.go @@ -829,16 +829,6 @@ func checkServerStats(t *testing.T, got []*gotData, expect *expectedData, checkF t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs)) } - var rpcctx context.Context - for i := 0; i < len(got); i++ { - if _, ok := got[i].s.(stats.RPCStats); ok { - if rpcctx != nil && got[i].ctx != rpcctx { - t.Fatalf("got different contexts with stats %T", got[i].s) - } - rpcctx = got[i].ctx - } - } - for i, f := range checkFuncs { f(t, got[i], expect) } From 17cb68dd4d40a3475077f08c967464f20aa78778 Mon Sep 17 00:00:00 2001 From: Zach Reyes Date: Wed, 30 Aug 2023 14:15:08 -0400 Subject: [PATCH 2/2] Offline discussion, stick trace in context --- server.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/server.go b/server.go index 60cb5a27c50a..a7b5532dac2d 100644 --- a/server.go +++ b/server.go @@ -1693,8 +1693,10 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str ctx := stream.Context() var ti *traceInfo if EnableTracing { + tr := trace.New("grpc.Recv."+methodFamily(stream.Method()), stream.Method()) + ctx = trace.NewContext(ctx, tr) ti = &traceInfo{ - tr: trace.New("grpc.Recv."+methodFamily(stream.Method()), stream.Method()), + tr: tr, firstLine: firstLine{ client: false, remoteAddr: t.RemoteAddr(),