Skip to content

Commit

Permalink
fix(grpc): aio grpc client interceptor incorrect metadata handling
Browse files Browse the repository at this point in the history
  • Loading branch information
legau committed Mar 20, 2024
1 parent 70d9b3d commit 3c1c5a9
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@

import functools
import logging
from collections import OrderedDict

import grpc
from grpc.aio import ClientCallDetails
from grpc.aio import ClientCallDetails, Metadata

from opentelemetry.instrumentation.grpc._client import (
OpenTelemetryClientInterceptor,
Expand Down Expand Up @@ -55,20 +54,19 @@ def callback(call):

class _BaseAioClientInterceptor(OpenTelemetryClientInterceptor):
@staticmethod
def propagate_trace_in_details(client_call_details):
def propagate_trace_in_details(client_call_details: ClientCallDetails):
metadata = client_call_details.metadata
if not metadata:
mutable_metadata = OrderedDict()
mutable_metadata = Metadata()
else:
mutable_metadata = OrderedDict(metadata)
mutable_metadata = Metadata(*tuple(metadata))

inject(mutable_metadata, setter=_carrier_setter)
metadata = tuple(mutable_metadata.items())

return ClientCallDetails(
client_call_details.method,
client_call_details.timeout,
metadata,
mutable_metadata,
client_call_details.credentials,
client_call_details.wait_for_ready,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async def simple_method(stub, error=False):
request = Request(
client_id=CLIENT_ID, request_data="error" if error else "data"
)
return await stub.SimpleMethod(request)
return await stub.SimpleMethod(request, metadata=(("key", "value"),))


async def client_streaming_method(stub, error=False):
Expand All @@ -41,7 +41,7 @@ def server_streaming_method(stub, error=False):
client_id=CLIENT_ID, request_data="error" if error else "data"
)

return stub.ServerStreamingMethod(request)
return stub.ServerStreamingMethod(request, metadata=(("key", "value"),))


def bidirectional_streaming_method(stub, error=False):
Expand All @@ -53,4 +53,6 @@ def request_messages():
)
yield request

return stub.BidirectionalStreamingMethod(request_messages())
return stub.BidirectionalStreamingMethod(
request_messages(), metadata=(("key", "value"),)
)
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ def simple_method(stub, error=False):
request = Request(
client_id=CLIENT_ID, request_data="error" if error else "data"
)
stub.SimpleMethod(request)
stub.SimpleMethod(request, metadata=(("key", "value"),))


def simple_method_future(stub, error=False):
request = Request(
client_id=CLIENT_ID, request_data="error" if error else "data"
)
return stub.SimpleMethod.future(request)
return stub.SimpleMethod.future(request, metadata=(("key", "value"),))


def client_streaming_method(stub, error=False):
Expand All @@ -40,14 +40,18 @@ def request_messages():
)
yield request

stub.ClientStreamingMethod(request_messages())
stub.ClientStreamingMethod(
request_messages(), metadata=(("key", "value"),)
)


def server_streaming_method(stub, error=False):
request = Request(
client_id=CLIENT_ID, request_data="error" if error else "data"
)
response_iterator = stub.ServerStreamingMethod(request)
response_iterator = stub.ServerStreamingMethod(
request, metadata=(("key", "value"),)
)
list(response_iterator)


Expand All @@ -59,6 +63,8 @@ def request_messages():
)
yield request

response_iterator = stub.BidirectionalStreamingMethod(request_messages())
response_iterator = stub.BidirectionalStreamingMethod(
request_messages(), metadata=(("key", "value"),)
)

list(response_iterator)
Original file line number Diff line number Diff line change
Expand Up @@ -305,11 +305,10 @@ async def test_client_interceptor_trace_context_propagation(self):
await simple_method(stub)

metadata = recording_interceptor.recorded_details.metadata
assert len(metadata) == 2
assert metadata[0][0] == "mock-traceid"
assert metadata[0][1] == "0"
assert metadata[1][0] == "mock-spanid"
assert metadata[1][1] == "0"
assert len(metadata) == 3
assert metadata.get_all("key") == ["value"]
assert metadata.get_all("mock-traceid") == ["0"]
assert metadata.get_all("mock-spanid") == ["0"]
finally:
set_global_textmap(previous_propagator)

Expand Down

0 comments on commit 3c1c5a9

Please sign in to comment.