diff --git a/.golangci.yml b/.golangci.yml index 5906654..03807cf 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -52,6 +52,7 @@ linters: - maintidx # covered by gocyclo - maligned # readability trumps efficient struct packing - nlreturn # generous whitespace violates house style + - nonamedreturns # named returns are fine; it's *bare* returns that are bad - nosnakecase # deprecated in https://github.com/golangci/golangci-lint/pull/3065 - scopelint # deprecated by author - structcheck # abandoned diff --git a/context.go b/context.go new file mode 100644 index 0000000..70ad4af --- /dev/null +++ b/context.go @@ -0,0 +1,25 @@ +// Copyright 2022-2023 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build go1.21 + +package otelconnect + +import "context" + +// afterFunc calls context.AfterFunc. Build tags ensure that this function +// is only compiled when the Go version is at least 1.21. +func afterFunc(ctx context.Context, f func()) (stop func() bool) { + return context.AfterFunc(ctx, f) +} diff --git a/context_legacy.go b/context_legacy.go new file mode 100644 index 0000000..93ea3a7 --- /dev/null +++ b/context_legacy.go @@ -0,0 +1,43 @@ +// Copyright 2022-2023 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !go1.21 + +package otelconnect + +import ( + "context" + "sync/atomic" +) + +// afterFunc is a simple imitation of context.AfterFunc from Go 1.21. +// It is not as efficient as the real implementation, but it is sufficient +// for our purposes. +func afterFunc(ctx context.Context, f func()) (stop func() bool) { + ctx, cancel := context.WithCancel(ctx) + var once atomic.Bool + go func() { + <-ctx.Done() + if once.CompareAndSwap(false, true) { + f() + } + }() + return func() bool { + didStop := once.CompareAndSwap(false, true) + if didStop { + cancel() + } + return didStop + } +} diff --git a/interceptor.go b/interceptor.go index 9ec8622..223354e 100644 --- a/interceptor.go +++ b/interceptor.go @@ -225,28 +225,34 @@ func (i *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) conn )..., ) } + closeSpan := func() { + requestOnce.Do(setRequestAttributes) + // state.attributes is updated with the final error that was recorded. + // If error is nil a "success" is recorded on the span and on the final duration + // metric. The "rpc..status_code" is not defined for any other metrics for + // streams because the error only exists when finishing the stream. + if statusCode, ok := statusCodeAttribute(protocol, state.error); ok { + state.addAttributes(statusCode) + } + span.SetAttributes(state.attributes...) + span.SetAttributes(headerAttributes(protocol, responseKey, conn.ResponseHeader(), i.config.responseHeaderKeys)...) + span.SetStatus(clientSpanStatus(protocol, state.error)) + span.End() + instrumentation.requestsPerRPC.Record(ctx, state.sentCounter, + metric.WithAttributes(state.attributes...)) + instrumentation.responsesPerRPC.Record(ctx, state.receivedCounter, + metric.WithAttributes(state.attributes...)) + duration := i.config.now().Sub(requestStartTime).Milliseconds() + instrumentation.duration.Record(ctx, duration, + metric.WithAttributes(state.attributes...)) + } + stopCtxClose := afterFunc(ctx, closeSpan) return &streamingClientInterceptor{ StreamingClientConn: conn, onClose: func() { - requestOnce.Do(setRequestAttributes) - // state.attributes is updated with the final error that was recorded. - // If error is nil a "success" is recorded on the span and on the final duration - // metric. The "rpc..status_code" is not defined for any other metrics for - // streams because the error only exists when finishing the stream. - if statusCode, ok := statusCodeAttribute(protocol, state.error); ok { - state.addAttributes(statusCode) + if stopCtxClose() { + closeSpan() } - span.SetAttributes(state.attributes...) - span.SetAttributes(headerAttributes(protocol, responseKey, conn.ResponseHeader(), i.config.responseHeaderKeys)...) - span.SetStatus(clientSpanStatus(protocol, state.error)) - span.End() - instrumentation.requestsPerRPC.Record(ctx, state.sentCounter, - metric.WithAttributes(state.attributes...)) - instrumentation.responsesPerRPC.Record(ctx, state.receivedCounter, - metric.WithAttributes(state.attributes...)) - duration := i.config.now().Sub(requestStartTime).Milliseconds() - instrumentation.duration.Record(ctx, duration, - metric.WithAttributes(state.attributes...)) }, receive: func(msg any, conn connect.StreamingClientConn) error { return state.receive(ctx, msg, conn) diff --git a/payloadinterceptor.go b/payloadinterceptor.go index f734480..8e2f309 100644 --- a/payloadinterceptor.go +++ b/payloadinterceptor.go @@ -15,8 +15,6 @@ package otelconnect import ( - "sync" - connect "connectrpc.com/connect" ) @@ -26,11 +24,6 @@ type streamingClientInterceptor struct { receive func(any, connect.StreamingClientConn) error send func(any, connect.StreamingClientConn) error onClose func() - - mu sync.Mutex - requestClosed bool - responseClosed bool - onCloseCalled bool } func (s *streamingClientInterceptor) Receive(msg any) error { @@ -41,33 +34,9 @@ func (s *streamingClientInterceptor) Send(msg any) error { return s.send(msg, s.StreamingClientConn) } -func (s *streamingClientInterceptor) CloseRequest() error { - err := s.StreamingClientConn.CloseRequest() - s.mu.Lock() - s.requestClosed = true - shouldCall := s.responseClosed && !s.onCloseCalled - if shouldCall { - s.onCloseCalled = true - } - s.mu.Unlock() - if shouldCall { - s.onClose() - } - return err -} - func (s *streamingClientInterceptor) CloseResponse() error { err := s.StreamingClientConn.CloseResponse() - s.mu.Lock() - s.responseClosed = true - shouldCall := s.requestClosed && !s.onCloseCalled - if shouldCall { - s.onCloseCalled = true - } - s.mu.Unlock() - if shouldCall { - s.onClose() - } + s.onClose() return err }