Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove *Request to fix streaming client context #153

Merged
merged 4 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions attributes.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,20 @@ import (
semconv "go.opentelemetry.io/otel/semconv/v1.21.0"
)

// AttributeFilter is used to filter attributes out based on the [Request] and [attribute.KeyValue].
// If the filter returns true the attribute will be kept else it will be removed.
// AttributeFilter must be safe to call concurrently.
type AttributeFilter func(*Request, attribute.KeyValue) bool
// AttributeFilter is used to filter attributes out based on the [connect.Spec]
// and [attribute.KeyValue]. If the filter returns true the attribute will be
// kept else it will be removed. AttributeFilter must be safe to call concurrently.
type AttributeFilter func(connect.Spec, attribute.KeyValue) bool

func (filter AttributeFilter) filter(request *Request, values ...attribute.KeyValue) []attribute.KeyValue {
func (filter AttributeFilter) filter(spec connect.Spec, values ...attribute.KeyValue) []attribute.KeyValue {
if filter == nil {
return values
}
// Assign a new slice of zero length with the same underlying
// array as the values slice. This avoids unnecessary memory allocations.
filteredValues := values[:0]
for _, attr := range values {
if filter(request, attr) {
if filter(spec, attr) {
filteredValues = append(filteredValues, attr)
}
}
Expand Down Expand Up @@ -71,13 +71,13 @@ func procedureAttributes(procedure string) []attribute.KeyValue {
return attrs
}

func requestAttributes(req *Request) []attribute.KeyValue {
func requestAttributes(spec connect.Spec, peer connect.Peer) []attribute.KeyValue {
var attrs []attribute.KeyValue
if addr := req.Peer.Addr; addr != "" {
if addr := peer.Addr; addr != "" {
attrs = append(attrs, addressAttributes(addr)...)
}
name := strings.TrimLeft(req.Spec.Procedure, "/")
protocol := protocolToSemConv(req.Peer.Protocol)
name := strings.TrimLeft(spec.Procedure, "/")
protocol := protocolToSemConv(peer.Protocol)
attrs = append(attrs, semconv.RPCSystemKey.String(protocol))
attrs = append(attrs, procedureAttributes(name)...)
return attrs
Expand Down
80 changes: 33 additions & 47 deletions interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,28 +91,23 @@ func (i *Interceptor) getInstruments(isClient bool) *instruments {
func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, request connect.AnyRequest) (connect.AnyResponse, error) {
requestStartTime := i.config.now()
req := &Request{
Spec: request.Spec(),
Peer: request.Peer(),
Header: request.Header(),
}
if i.config.filter != nil {
if !i.config.filter(ctx, req) {
if !i.config.filter(ctx, request.Spec()) {
return next(ctx, request)
}
}
attributeFilter := i.config.filterAttribute.filter
isClient := request.Spec().IsClient
name := strings.TrimLeft(request.Spec().Procedure, "/")
protocol := protocolToSemConv(request.Peer().Protocol)
attributes := attributeFilter(req, requestAttributes(req)...)
attributes := attributeFilter(request.Spec(), requestAttributes(request.Spec(), request.Peer())...)
instrumentation := i.getInstruments(isClient)
carrier := propagation.HeaderCarrier(request.Header())
spanKind := trace.SpanKindClient
requestSpan, responseSpan := semconv.MessageTypeSent, semconv.MessageTypeReceived
traceOpts := []trace.SpanStartOption{
trace.WithAttributes(attributes...),
trace.WithAttributes(headerAttributes(protocol, requestKey, req.Header, i.config.requestHeaderKeys)...),
trace.WithAttributes(headerAttributes(protocol, requestKey, request.Header(), i.config.requestHeaderKeys)...),
}
if !isClient {
spanKind = trace.SpanKindServer
Expand Down Expand Up @@ -174,7 +169,7 @@ func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
),
)
}
attributes = attributeFilter(req, attributes...)
attributes = attributeFilter(request.Spec(), attributes...)
if isClient {
span.SetStatus(clientSpanStatus(protocol, err))
} else {
Expand All @@ -193,51 +188,47 @@ func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
// WrapStreamingClient implements otel tracing and metrics for streaming connect clients.
func (i *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn {
requestStartTime := i.config.now()
conn := next(ctx, spec)
instrumentation := i.getInstruments(spec.IsClient)
req := &Request{
Spec: conn.Spec(),
Peer: conn.Peer(),
Header: conn.RequestHeader(),
}
if i.config.filter != nil {
if !i.config.filter(ctx, req) {
return conn
if !i.config.filter(ctx, spec) {
return next(ctx, spec)
}
}
name := strings.TrimLeft(conn.Spec().Procedure, "/")
protocol := protocolToSemConv(conn.Peer().Protocol)
requestStartTime := i.config.now()
name := strings.TrimLeft(spec.Procedure, "/")
ctx, span := i.config.tracer.Start(
ctx,
name,
trace.WithSpanKind(trace.SpanKindClient),
)
conn := next(ctx, spec)
instrumentation := i.getInstruments(spec.IsClient)
// inject the newly created span into the carrier
carrier := propagation.HeaderCarrier(conn.RequestHeader())
i.config.propagator.Inject(ctx, carrier)
state := newStreamingState(
req,
spec,
conn.Peer(),
i.config.filterAttribute,
i.config.omitTraceEvents,
requestAttributes(req),
instrumentation.responseSize,
instrumentation.requestSize,
)
var span trace.Span
var createSpanOnce sync.Once
createSpan := func() {
ctx, span = i.config.tracer.Start(
ctx,
name,
trace.WithSpanKind(trace.SpanKindClient),
trace.WithAttributes(state.attributes...),
trace.WithAttributes(headerAttributes(
protocol := protocolToSemConv(conn.Peer().Protocol)
var requestOnce sync.Once
setRequestAttributes := func() {
span.SetAttributes(
headerAttributes(
protocol,
requestKey,
conn.RequestHeader(),
i.config.requestHeaderKeys)...),
i.config.requestHeaderKeys,
)...,
)
// inject the newly created span into the carrier
carrier := propagation.HeaderCarrier(conn.RequestHeader())
i.config.propagator.Inject(ctx, carrier)
}
return &streamingClientInterceptor{
StreamingClientConn: conn,
onClose: func() {
createSpanOnce.Do(createSpan)
requestOnce.Do(setRequestAttributes)
// state.attributes is updated with the final error that was recorded.
// If error is nil a "success" is recorded on the span and on the final duration
// metric. The "rpc.<protocol>.status_code" is not defined for any other metrics for
Expand All @@ -261,7 +252,7 @@ func (i *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) conn
return state.receive(ctx, msg, conn)
},
send: func(msg any, conn connect.StreamingClientConn) error {
createSpanOnce.Do(createSpan)
requestOnce.Do(setRequestAttributes)
return state.send(ctx, msg, conn)
},
}
Expand All @@ -274,23 +265,18 @@ func (i *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) co
requestStartTime := i.config.now()
isClient := conn.Spec().IsClient
instrumentation := i.getInstruments(isClient)
req := &Request{
Spec: conn.Spec(),
Peer: conn.Peer(),
Header: conn.RequestHeader(),
}
if i.config.filter != nil {
if !i.config.filter(ctx, req) {
if !i.config.filter(ctx, conn.Spec()) {
return next(ctx, conn)
}
}
name := strings.TrimLeft(conn.Spec().Procedure, "/")
protocol := protocolToSemConv(conn.Peer().Protocol)
state := newStreamingState(
req,
conn.Spec(),
conn.Peer(),
i.config.filterAttribute,
i.config.omitTraceEvents,
requestAttributes(req),
instrumentation.requestSize,
instrumentation.responseSize,
)
Expand All @@ -299,7 +285,7 @@ func (i *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) co
traceOpts := []trace.SpanStartOption{
trace.WithSpanKind(trace.SpanKindServer),
trace.WithAttributes(state.attributes...),
trace.WithAttributes(headerAttributes(protocol, requestKey, req.Header, i.config.requestHeaderKeys)...),
trace.WithAttributes(headerAttributes(protocol, requestKey, conn.RequestHeader(), i.config.requestHeaderKeys)...),
}
if !trace.SpanContextFromContext(ctx).IsValid() {
ctx = i.config.propagator.Extract(ctx, carrier)
Expand Down
92 changes: 33 additions & 59 deletions interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1016,7 +1016,7 @@ func TestClientHandlerOpts(t *testing.T) {
clientTraceProvider := trace.NewTracerProvider(trace.WithSpanProcessor(clientSpanRecorder))
serverInterceptor, err := NewInterceptor(
WithTracerProvider(serverTraceProvider),
WithFilter(func(ctx context.Context, request *Request) bool {
WithFilter(func(ctx context.Context, spec connect.Spec) bool {
return false
}),
)
Expand Down Expand Up @@ -1073,7 +1073,7 @@ func TestBasicFilter(t *testing.T) {
headerKey, headerVal := "Some-Header", "foobar"
spanRecorder := tracetest.NewSpanRecorder()
traceProvider := trace.NewTracerProvider(trace.WithSpanProcessor(spanRecorder))
serverInterceptor, err := NewInterceptor(WithTracerProvider(traceProvider), WithFilter(func(ctx context.Context, request *Request) bool {
serverInterceptor, err := NewInterceptor(WithTracerProvider(traceProvider), WithFilter(func(ctx context.Context, spec connect.Spec) bool {
return false
}))
require.NoError(t, err)
Expand All @@ -1091,58 +1091,6 @@ func TestBasicFilter(t *testing.T) {
assertSpans(t, []wantSpans{}, spanRecorder.Ended())
}

func TestFilterHeader(t *testing.T) {
t.Parallel()
headerKey, headerVal := "Some-Header", "foobar"
spanRecorder := tracetest.NewSpanRecorder()
traceProvider := trace.NewTracerProvider(trace.WithSpanProcessor(spanRecorder))
serverInterceptor, err := NewInterceptor(WithTracerProvider(traceProvider), WithFilter(func(ctx context.Context, request *Request) bool {
return request.Header.Get(headerKey) == headerVal
}))
require.NoError(t, err)
pingClient, host, port := startServer([]connect.HandlerOption{
connect.WithInterceptors(serverInterceptor),
}, nil, okayPingServer())
req := requestOfSize(1, 0)
req.Header().Set(headerKey, headerVal)
if _, err := pingClient.Ping(context.Background(), req); err != nil {
t.Errorf(err.Error())
}
if _, err := pingClient.Ping(context.Background(), requestOfSize(1, 0)); err != nil {
t.Errorf(err.Error())
}
assertSpans(t, []wantSpans{
{
spanName: pingv1connect.PingServiceName + "/" + PingMethod,
events: []trace.Event{
{
Name: messageKey,
Attributes: []attribute.KeyValue{
semconv.MessageTypeReceived,
semconv.MessageIDKey.Int(1),
semconv.MessageUncompressedSizeKey.Int(2),
},
},
{
Name: messageKey,
Attributes: []attribute.KeyValue{
semconv.MessageTypeSent,
semconv.MessageIDKey.Int(1),
semconv.MessageUncompressedSizeKey.Int(2),
},
},
},
attrs: []attribute.KeyValue{
semconv.NetPeerNameKey.String(host),
semconv.NetPeerPortKey.Int(port),
semconv.RPCSystemKey.String(bufConnect),
semconv.RPCServiceKey.String(pingv1connect.PingServiceName),
semconv.RPCMethodKey.String(PingMethod),
},
},
}, spanRecorder.Ended())
}

func TestHeaderAttribute(t *testing.T) {
t.Parallel()
var propagator propagation.TraceContext
Expand Down Expand Up @@ -1408,9 +1356,9 @@ func TestUnaryPropagation(t *testing.T) {
require.NoError(t, err)
client, _, _ := startServer(
[]connect.HandlerOption{
connect.WithInterceptors(serverInterceptor),
connect.WithInterceptors(serverInterceptor, assertSpanInterceptor{t: t}),
}, []connect.ClientOption{
connect.WithInterceptors(clientInterceptor),
connect.WithInterceptors(clientInterceptor, assertSpanInterceptor{t: t}),
}, okayPingServer())
_, err = client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{Id: 1}))
require.NoError(t, err)
Expand Down Expand Up @@ -1643,7 +1591,7 @@ func TestStreamingClientPropagation(t *testing.T) {
)
require.NoError(t, err)
client, _, _ := startServer(nil, []connect.ClientOption{
connect.WithInterceptors(clientInterceptor),
connect.WithInterceptors(clientInterceptor, assertSpanInterceptor{t: t}),
}, &pluggablePingServer{pingStream: assertTraceParent},
)
stream := client.PingStream(context.Background())
Expand All @@ -1662,7 +1610,7 @@ func TestStreamingHandlerTracing(t *testing.T) {
serverInterceptor, err := NewInterceptor(WithTracerProvider(traceProvider))
require.NoError(t, err)
pingClient, host, port := startServer([]connect.HandlerOption{
connect.WithInterceptors(serverInterceptor),
connect.WithInterceptors(serverInterceptor, assertSpanInterceptor{t: t}),
}, nil, okayPingServer())
stream := pingClient.PingStream(context.Background())

Expand Down Expand Up @@ -1767,7 +1715,7 @@ func TestWithAttributeFilter(t *testing.T) {
traceProvider := trace.NewTracerProvider(trace.WithSpanProcessor(spanRecorder))
clientInterceptor, err := NewInterceptor(
WithTracerProvider(traceProvider),
WithAttributeFilter(func(_ *Request, value attribute.KeyValue) bool {
WithAttributeFilter(func(_ connect.Spec, value attribute.KeyValue) bool {
if value.Key == semconv.MessageIDKey {
return false
}
Expand Down Expand Up @@ -2230,3 +2178,29 @@ func serverSpanStatusTestCases() []serverSpanStatusTestCase {
{connectCode: connect.CodeUnauthenticated, wantServerSpanCode: codes.Unset, wantServerSpanDescription: ""},
}
}

type assertSpanInterceptor struct{ t testing.TB }

func (i assertSpanInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, request connect.AnyRequest) (connect.AnyResponse, error) {
i.assertSpanContext(ctx)
return next(ctx, request)
}
}
func (i assertSpanInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn {
i.assertSpanContext(ctx)
return next(ctx, spec)
}
}
func (i assertSpanInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return func(ctx context.Context, conn connect.StreamingHandlerConn) error {
i.assertSpanContext(ctx)
return next(ctx, conn)
}
}
func (i assertSpanInterceptor) assertSpanContext(ctx context.Context) {
if !traceapi.SpanContextFromContext(ctx).IsValid() {
i.t.Error("invalid span context")
}
}
9 changes: 5 additions & 4 deletions option.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"net/http"

connect "connectrpc.com/connect"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/metric/noop"
Expand Down Expand Up @@ -54,7 +55,7 @@ func WithTracerProvider(provider trace.TracerProvider) Option {

// WithFilter configures the instrumentation to emit traces and metrics only
// when the filter function returns true. Filter functions must be safe to call concurrently.
func WithFilter(filter func(context.Context, *Request) bool) Option {
func WithFilter(filter func(context.Context, connect.Spec) bool) Option {
return &filterOption{filter}
}

Expand All @@ -79,8 +80,8 @@ func WithAttributeFilter(filter AttributeFilter) Option {
// high-cardinality data; this option significantly reduces cardinality in most
// environments.
func WithoutServerPeerAttributes() Option {
return WithAttributeFilter(func(request *Request, value attribute.KeyValue) bool {
if request.Spec.IsClient {
return WithAttributeFilter(func(spec connect.Spec, value attribute.KeyValue) bool {
if spec.IsClient {
return true
}
if value.Key == semconv.NetPeerPortKey {
Expand Down Expand Up @@ -158,7 +159,7 @@ func (o *tracerProviderOption) apply(c *config) {
}

type filterOption struct {
filter func(context.Context, *Request) bool
filter func(context.Context, connect.Spec) bool
}

func (o *filterOption) apply(c *config) {
Expand Down
Loading