Skip to content

Commit

Permalink
Remove *Request to fix streaming client context (#153)
Browse files Browse the repository at this point in the history
Context in client streams doesn't include the span due to the `*Request`
requiring information only included once the connection is created. This
PR removes the `*Request` in favor of simply providing the `connect.Spec`
information. This allows for correctly initializing the span so its included in
the context passed to the machinery to issue the call.

Backward-incompatible changes:
* Removes exported `*Request` object (breaking change)
* Changes option filters to use `connect.Spec` instead of `*Request`
(breaking change)
  • Loading branch information
emcfarlane authored Dec 19, 2023
1 parent 07e7942 commit 6ccd433
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 139 deletions.
20 changes: 10 additions & 10 deletions attributes.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,20 @@ import (
semconv "go.opentelemetry.io/otel/semconv/v1.21.0"
)

// AttributeFilter is used to filter attributes out based on the [Request] and [attribute.KeyValue].
// If the filter returns true the attribute will be kept else it will be removed.
// AttributeFilter must be safe to call concurrently.
type AttributeFilter func(*Request, attribute.KeyValue) bool
// AttributeFilter is used to filter attributes out based on the [connect.Spec]
// and [attribute.KeyValue]. If the filter returns true the attribute will be
// kept else it will be removed. AttributeFilter must be safe to call concurrently.
type AttributeFilter func(connect.Spec, attribute.KeyValue) bool

func (filter AttributeFilter) filter(request *Request, values ...attribute.KeyValue) []attribute.KeyValue {
func (filter AttributeFilter) filter(spec connect.Spec, values ...attribute.KeyValue) []attribute.KeyValue {
if filter == nil {
return values
}
// Assign a new slice of zero length with the same underlying
// array as the values slice. This avoids unnecessary memory allocations.
filteredValues := values[:0]
for _, attr := range values {
if filter(request, attr) {
if filter(spec, attr) {
filteredValues = append(filteredValues, attr)
}
}
Expand Down Expand Up @@ -71,13 +71,13 @@ func procedureAttributes(procedure string) []attribute.KeyValue {
return attrs
}

func requestAttributes(req *Request) []attribute.KeyValue {
func requestAttributes(spec connect.Spec, peer connect.Peer) []attribute.KeyValue {
var attrs []attribute.KeyValue
if addr := req.Peer.Addr; addr != "" {
if addr := peer.Addr; addr != "" {
attrs = append(attrs, addressAttributes(addr)...)
}
name := strings.TrimLeft(req.Spec.Procedure, "/")
protocol := protocolToSemConv(req.Peer.Protocol)
name := strings.TrimLeft(spec.Procedure, "/")
protocol := protocolToSemConv(peer.Protocol)
attrs = append(attrs, semconv.RPCSystemKey.String(protocol))
attrs = append(attrs, procedureAttributes(name)...)
return attrs
Expand Down
80 changes: 33 additions & 47 deletions interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,28 +91,23 @@ func (i *Interceptor) getInstruments(isClient bool) *instruments {
func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, request connect.AnyRequest) (connect.AnyResponse, error) {
requestStartTime := i.config.now()
req := &Request{
Spec: request.Spec(),
Peer: request.Peer(),
Header: request.Header(),
}
if i.config.filter != nil {
if !i.config.filter(ctx, req) {
if !i.config.filter(ctx, request.Spec()) {
return next(ctx, request)
}
}
attributeFilter := i.config.filterAttribute.filter
isClient := request.Spec().IsClient
name := strings.TrimLeft(request.Spec().Procedure, "/")
protocol := protocolToSemConv(request.Peer().Protocol)
attributes := attributeFilter(req, requestAttributes(req)...)
attributes := attributeFilter(request.Spec(), requestAttributes(request.Spec(), request.Peer())...)
instrumentation := i.getInstruments(isClient)
carrier := propagation.HeaderCarrier(request.Header())
spanKind := trace.SpanKindClient
requestSpan, responseSpan := semconv.MessageTypeSent, semconv.MessageTypeReceived
traceOpts := []trace.SpanStartOption{
trace.WithAttributes(attributes...),
trace.WithAttributes(headerAttributes(protocol, requestKey, req.Header, i.config.requestHeaderKeys)...),
trace.WithAttributes(headerAttributes(protocol, requestKey, request.Header(), i.config.requestHeaderKeys)...),
}
if !isClient {
spanKind = trace.SpanKindServer
Expand Down Expand Up @@ -174,7 +169,7 @@ func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
),
)
}
attributes = attributeFilter(req, attributes...)
attributes = attributeFilter(request.Spec(), attributes...)
if isClient {
span.SetStatus(clientSpanStatus(protocol, err))
} else {
Expand All @@ -193,51 +188,47 @@ func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
// WrapStreamingClient implements otel tracing and metrics for streaming connect clients.
func (i *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn {
requestStartTime := i.config.now()
conn := next(ctx, spec)
instrumentation := i.getInstruments(spec.IsClient)
req := &Request{
Spec: conn.Spec(),
Peer: conn.Peer(),
Header: conn.RequestHeader(),
}
if i.config.filter != nil {
if !i.config.filter(ctx, req) {
return conn
if !i.config.filter(ctx, spec) {
return next(ctx, spec)
}
}
name := strings.TrimLeft(conn.Spec().Procedure, "/")
protocol := protocolToSemConv(conn.Peer().Protocol)
requestStartTime := i.config.now()
name := strings.TrimLeft(spec.Procedure, "/")
ctx, span := i.config.tracer.Start(
ctx,
name,
trace.WithSpanKind(trace.SpanKindClient),
)
conn := next(ctx, spec)
instrumentation := i.getInstruments(spec.IsClient)
// inject the newly created span into the carrier
carrier := propagation.HeaderCarrier(conn.RequestHeader())
i.config.propagator.Inject(ctx, carrier)
state := newStreamingState(
req,
spec,
conn.Peer(),
i.config.filterAttribute,
i.config.omitTraceEvents,
requestAttributes(req),
instrumentation.responseSize,
instrumentation.requestSize,
)
var span trace.Span
var createSpanOnce sync.Once
createSpan := func() {
ctx, span = i.config.tracer.Start(
ctx,
name,
trace.WithSpanKind(trace.SpanKindClient),
trace.WithAttributes(state.attributes...),
trace.WithAttributes(headerAttributes(
protocol := protocolToSemConv(conn.Peer().Protocol)
var requestOnce sync.Once
setRequestAttributes := func() {
span.SetAttributes(
headerAttributes(
protocol,
requestKey,
conn.RequestHeader(),
i.config.requestHeaderKeys)...),
i.config.requestHeaderKeys,
)...,
)
// inject the newly created span into the carrier
carrier := propagation.HeaderCarrier(conn.RequestHeader())
i.config.propagator.Inject(ctx, carrier)
}
return &streamingClientInterceptor{
StreamingClientConn: conn,
onClose: func() {
createSpanOnce.Do(createSpan)
requestOnce.Do(setRequestAttributes)
// state.attributes is updated with the final error that was recorded.
// If error is nil a "success" is recorded on the span and on the final duration
// metric. The "rpc.<protocol>.status_code" is not defined for any other metrics for
Expand All @@ -261,7 +252,7 @@ func (i *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) conn
return state.receive(ctx, msg, conn)
},
send: func(msg any, conn connect.StreamingClientConn) error {
createSpanOnce.Do(createSpan)
requestOnce.Do(setRequestAttributes)
return state.send(ctx, msg, conn)
},
}
Expand All @@ -274,23 +265,18 @@ func (i *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) co
requestStartTime := i.config.now()
isClient := conn.Spec().IsClient
instrumentation := i.getInstruments(isClient)
req := &Request{
Spec: conn.Spec(),
Peer: conn.Peer(),
Header: conn.RequestHeader(),
}
if i.config.filter != nil {
if !i.config.filter(ctx, req) {
if !i.config.filter(ctx, conn.Spec()) {
return next(ctx, conn)
}
}
name := strings.TrimLeft(conn.Spec().Procedure, "/")
protocol := protocolToSemConv(conn.Peer().Protocol)
state := newStreamingState(
req,
conn.Spec(),
conn.Peer(),
i.config.filterAttribute,
i.config.omitTraceEvents,
requestAttributes(req),
instrumentation.requestSize,
instrumentation.responseSize,
)
Expand All @@ -299,7 +285,7 @@ func (i *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) co
traceOpts := []trace.SpanStartOption{
trace.WithSpanKind(trace.SpanKindServer),
trace.WithAttributes(state.attributes...),
trace.WithAttributes(headerAttributes(protocol, requestKey, req.Header, i.config.requestHeaderKeys)...),
trace.WithAttributes(headerAttributes(protocol, requestKey, conn.RequestHeader(), i.config.requestHeaderKeys)...),
}
if !trace.SpanContextFromContext(ctx).IsValid() {
ctx = i.config.propagator.Extract(ctx, carrier)
Expand Down
92 changes: 33 additions & 59 deletions interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1016,7 +1016,7 @@ func TestClientHandlerOpts(t *testing.T) {
clientTraceProvider := trace.NewTracerProvider(trace.WithSpanProcessor(clientSpanRecorder))
serverInterceptor, err := NewInterceptor(
WithTracerProvider(serverTraceProvider),
WithFilter(func(ctx context.Context, request *Request) bool {
WithFilter(func(ctx context.Context, spec connect.Spec) bool {
return false
}),
)
Expand Down Expand Up @@ -1073,7 +1073,7 @@ func TestBasicFilter(t *testing.T) {
headerKey, headerVal := "Some-Header", "foobar"
spanRecorder := tracetest.NewSpanRecorder()
traceProvider := trace.NewTracerProvider(trace.WithSpanProcessor(spanRecorder))
serverInterceptor, err := NewInterceptor(WithTracerProvider(traceProvider), WithFilter(func(ctx context.Context, request *Request) bool {
serverInterceptor, err := NewInterceptor(WithTracerProvider(traceProvider), WithFilter(func(ctx context.Context, spec connect.Spec) bool {
return false
}))
require.NoError(t, err)
Expand All @@ -1091,58 +1091,6 @@ func TestBasicFilter(t *testing.T) {
assertSpans(t, []wantSpans{}, spanRecorder.Ended())
}

func TestFilterHeader(t *testing.T) {
t.Parallel()
headerKey, headerVal := "Some-Header", "foobar"
spanRecorder := tracetest.NewSpanRecorder()
traceProvider := trace.NewTracerProvider(trace.WithSpanProcessor(spanRecorder))
serverInterceptor, err := NewInterceptor(WithTracerProvider(traceProvider), WithFilter(func(ctx context.Context, request *Request) bool {
return request.Header.Get(headerKey) == headerVal
}))
require.NoError(t, err)
pingClient, host, port := startServer([]connect.HandlerOption{
connect.WithInterceptors(serverInterceptor),
}, nil, okayPingServer())
req := requestOfSize(1, 0)
req.Header().Set(headerKey, headerVal)
if _, err := pingClient.Ping(context.Background(), req); err != nil {
t.Errorf(err.Error())
}
if _, err := pingClient.Ping(context.Background(), requestOfSize(1, 0)); err != nil {
t.Errorf(err.Error())
}
assertSpans(t, []wantSpans{
{
spanName: pingv1connect.PingServiceName + "/" + PingMethod,
events: []trace.Event{
{
Name: messageKey,
Attributes: []attribute.KeyValue{
semconv.MessageTypeReceived,
semconv.MessageIDKey.Int(1),
semconv.MessageUncompressedSizeKey.Int(2),
},
},
{
Name: messageKey,
Attributes: []attribute.KeyValue{
semconv.MessageTypeSent,
semconv.MessageIDKey.Int(1),
semconv.MessageUncompressedSizeKey.Int(2),
},
},
},
attrs: []attribute.KeyValue{
semconv.NetPeerNameKey.String(host),
semconv.NetPeerPortKey.Int(port),
semconv.RPCSystemKey.String(bufConnect),
semconv.RPCServiceKey.String(pingv1connect.PingServiceName),
semconv.RPCMethodKey.String(PingMethod),
},
},
}, spanRecorder.Ended())
}

func TestHeaderAttribute(t *testing.T) {
t.Parallel()
var propagator propagation.TraceContext
Expand Down Expand Up @@ -1408,9 +1356,9 @@ func TestUnaryPropagation(t *testing.T) {
require.NoError(t, err)
client, _, _ := startServer(
[]connect.HandlerOption{
connect.WithInterceptors(serverInterceptor),
connect.WithInterceptors(serverInterceptor, assertSpanInterceptor{t: t}),
}, []connect.ClientOption{
connect.WithInterceptors(clientInterceptor),
connect.WithInterceptors(clientInterceptor, assertSpanInterceptor{t: t}),
}, okayPingServer())
_, err = client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{Id: 1}))
require.NoError(t, err)
Expand Down Expand Up @@ -1643,7 +1591,7 @@ func TestStreamingClientPropagation(t *testing.T) {
)
require.NoError(t, err)
client, _, _ := startServer(nil, []connect.ClientOption{
connect.WithInterceptors(clientInterceptor),
connect.WithInterceptors(clientInterceptor, assertSpanInterceptor{t: t}),
}, &pluggablePingServer{pingStream: assertTraceParent},
)
stream := client.PingStream(context.Background())
Expand All @@ -1662,7 +1610,7 @@ func TestStreamingHandlerTracing(t *testing.T) {
serverInterceptor, err := NewInterceptor(WithTracerProvider(traceProvider))
require.NoError(t, err)
pingClient, host, port := startServer([]connect.HandlerOption{
connect.WithInterceptors(serverInterceptor),
connect.WithInterceptors(serverInterceptor, assertSpanInterceptor{t: t}),
}, nil, okayPingServer())
stream := pingClient.PingStream(context.Background())

Expand Down Expand Up @@ -1767,7 +1715,7 @@ func TestWithAttributeFilter(t *testing.T) {
traceProvider := trace.NewTracerProvider(trace.WithSpanProcessor(spanRecorder))
clientInterceptor, err := NewInterceptor(
WithTracerProvider(traceProvider),
WithAttributeFilter(func(_ *Request, value attribute.KeyValue) bool {
WithAttributeFilter(func(_ connect.Spec, value attribute.KeyValue) bool {
if value.Key == semconv.MessageIDKey {
return false
}
Expand Down Expand Up @@ -2230,3 +2178,29 @@ func serverSpanStatusTestCases() []serverSpanStatusTestCase {
{connectCode: connect.CodeUnauthenticated, wantServerSpanCode: codes.Unset, wantServerSpanDescription: ""},
}
}

type assertSpanInterceptor struct{ t testing.TB }

func (i assertSpanInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, request connect.AnyRequest) (connect.AnyResponse, error) {
i.assertSpanContext(ctx)
return next(ctx, request)
}
}
func (i assertSpanInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn {
i.assertSpanContext(ctx)
return next(ctx, spec)
}
}
func (i assertSpanInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return func(ctx context.Context, conn connect.StreamingHandlerConn) error {
i.assertSpanContext(ctx)
return next(ctx, conn)
}
}
func (i assertSpanInterceptor) assertSpanContext(ctx context.Context) {
if !traceapi.SpanContextFromContext(ctx).IsValid() {
i.t.Error("invalid span context")
}
}
9 changes: 5 additions & 4 deletions option.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"net/http"

connect "connectrpc.com/connect"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/metric/noop"
Expand Down Expand Up @@ -54,7 +55,7 @@ func WithTracerProvider(provider trace.TracerProvider) Option {

// WithFilter configures the instrumentation to emit traces and metrics only
// when the filter function returns true. Filter functions must be safe to call concurrently.
func WithFilter(filter func(context.Context, *Request) bool) Option {
func WithFilter(filter func(context.Context, connect.Spec) bool) Option {
return &filterOption{filter}
}

Expand All @@ -79,8 +80,8 @@ func WithAttributeFilter(filter AttributeFilter) Option {
// high-cardinality data; this option significantly reduces cardinality in most
// environments.
func WithoutServerPeerAttributes() Option {
return WithAttributeFilter(func(request *Request, value attribute.KeyValue) bool {
if request.Spec.IsClient {
return WithAttributeFilter(func(spec connect.Spec, value attribute.KeyValue) bool {
if spec.IsClient {
return true
}
if value.Key == semconv.NetPeerPortKey {
Expand Down Expand Up @@ -158,7 +159,7 @@ func (o *tracerProviderOption) apply(c *config) {
}

type filterOption struct {
filter func(context.Context, *Request) bool
filter func(context.Context, connect.Spec) bool
}

func (o *filterOption) apply(c *config) {
Expand Down
Loading

0 comments on commit 6ccd433

Please sign in to comment.