Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix arity of context.abort for AIO RPCs #2066

Merged
merged 9 commits into from
Nov 28, 2023
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- `opentelemetry-instrumentation` Added Otel semantic convention opt-in mechanism
([#1987](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1987))
- `opentelemetry-instrument-grpc` Fix arity of context.abort for AIO RPCs
cookiefission marked this conversation as resolved.
Show resolved Hide resolved
([#2066](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2066))

## Version 1.21.0/0.42b0 (2023-11-01)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,133 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import grpc
import grpc.aio

from ._server import (
OpenTelemetryServerInterceptor,
_OpenTelemetryServicerContext,
_wrap_rpc_behavior,
)

from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace.status import Status, StatusCode

# pylint:disable=abstract-method
class _OpenTelemetryAioServicerContext(grpc.aio.ServicerContext):
cookiefission marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, servicer_context, active_span):
self._servicer_context = servicer_context
self._active_span = active_span
self._code = grpc.StatusCode.OK
self._details = None
super().__init__()

def __getattr__(self, attr):
return getattr(self._servicer_context, attr)

async def read(self):
return await self._servicer_context.read()

async def write(self, message):
return await self._servicer_context.write(message)

def is_active(self, *args, **kwargs):
return self._servicer_context.is_active(*args, **kwargs)

def time_remaining(self, *args, **kwargs):
return self._servicer_context.time_remaining(*args, **kwargs)

def cancel(self, *args, **kwargs):
return self._servicer_context.cancel(*args, **kwargs)

def add_callback(self, *args, **kwargs):
return self._servicer_context.add_callback(*args, **kwargs)

def disable_next_message_compression(self):
return self._service_context.disable_next_message_compression()

def invocation_metadata(self, *args, **kwargs):
return self._servicer_context.invocation_metadata(*args, **kwargs)

def peer(self):
return self._servicer_context.peer()

def peer_identities(self):
return self._servicer_context.peer_identities()

def peer_identity_key(self):
return self._servicer_context.peer_identity_key()

def auth_context(self):
return self._servicer_context.auth_context()

def set_compression(self, compression):
return self._servicer_context.set_compression(compression)

async def send_initial_metadata(self, *args, **kwargs):
return await self._servicer_context.send_initial_metadata(*args, **kwargs)

def set_trailing_metadata(self, *args, **kwargs):
return self._servicer_context.set_trailing_metadata(*args, **kwargs)

def trailing_metadata(self):
return self._servicer_context.trailing_metadata()

async def abort(self, code, details = "", trailing_metadata = tuple()):
self._code = code
self._details = details
self._active_span.set_attribute(
SpanAttributes.RPC_GRPC_STATUS_CODE, code.value[0]
)
self._active_span.set_status(
Status(
status_code=StatusCode.ERROR,
description=f"{code}:{details}",
)
)
return await self._servicer_context.abort(code, details, trailing_metadata)

def code(self):
if not hasattr(self._servicer_context, "code"):
raise RuntimeError(
"code() is not supported with the installed version of grpcio"
)
return self._servicer_context.code()

def details(self):
if not hasattr(self._servicer_context, "details"):
raise RuntimeError(
"details() is not supported with the installed version of "
"grpcio"
)
return self._servicer_context.details()

def set_code(self, code):
self._code = code
# use details if we already have it, otherwise the status description
details = self._details or code.value[1]
self._active_span.set_attribute(
SpanAttributes.RPC_GRPC_STATUS_CODE, code.value[0]
)
if code != grpc.StatusCode.OK:
cookiefission marked this conversation as resolved.
Show resolved Hide resolved
self._active_span.set_status(
Status(
status_code=StatusCode.ERROR,
description=f"{code}:{details}",
)
)
return self._servicer_context.set_code(code)

def set_details(self, details):
self._details = details
if self._code != grpc.StatusCode.OK:
self._active_span.set_status(
Status(
status_code=StatusCode.ERROR,
description=f"{self._code}:{details}",
)
)
return self._servicer_context.set_details(details)


class OpenTelemetryAioServerInterceptor(
grpc.aio.ServerInterceptor, OpenTelemetryServerInterceptor
Expand Down Expand Up @@ -66,7 +185,7 @@ async def _unary_interceptor(request_or_iterator, context):
set_status_on_exception=False,
) as span:
# wrap the context
context = _OpenTelemetryServicerContext(context, span)
context = _OpenTelemetryAioServicerContext(context, span)

# And now we run the actual RPC.
try:
Expand All @@ -91,7 +210,7 @@ async def _stream_interceptor(request_or_iterator, context):
context,
set_status_on_exception=False,
) as span:
context = _OpenTelemetryServicerContext(context, span)
context = _OpenTelemetryAioServicerContext(context, span)

try:
async for response in behavior(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,11 @@ async def run_with_test_server(
channel = grpc.aio.insecure_channel(f"localhost:{port:d}")

await server.start()
resp = await runnable(channel)
await server.stop(1000)

try:
resp = await runnable(channel)
finally:
await server.stop(1000)

return resp

Expand Down Expand Up @@ -514,9 +517,75 @@ async def request(channel):
request = Request(client_id=1, request_data=failure_message)
msg = request.SerializeToString()

with testcase.assertRaises(Exception):
with testcase.assertRaises(grpc.RpcError) as cm:
await channel.unary_unary(rpc_call)(msg)

self.assertEqual(cm.exception.code(), grpc.StatusCode.FAILED_PRECONDITION)
self.assertEqual(cm.exception.details(), failure_message)

await run_with_test_server(request, servicer=AbortServicer())

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]

self.assertEqual(span.name, rpc_call)
self.assertIs(span.kind, trace.SpanKind.SERVER)

# Check version and name in span's instrumentation info
self.assertEqualSpanInstrumentationInfo(
span, opentelemetry.instrumentation.grpc
)

# make sure this span errored, with the right status and detail
self.assertEqual(span.status.status_code, StatusCode.ERROR)
self.assertEqual(
span.status.description,
f"{grpc.StatusCode.FAILED_PRECONDITION}:{failure_message}",
)

# Check attributes
self.assertSpanHasAttributes(
span,
{
SpanAttributes.NET_PEER_IP: "[::1]",
SpanAttributes.NET_PEER_NAME: "localhost",
SpanAttributes.RPC_METHOD: "SimpleMethod",
SpanAttributes.RPC_SERVICE: "GRPCTestServer",
SpanAttributes.RPC_SYSTEM: "grpc",
SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.FAILED_PRECONDITION.value[
0
],
},
)

async def test_abort_with_trailing_metadata(self):
"""Check that we can catch an abort properly when trailing_metadata provided"""
rpc_call = "/GRPCTestServer/SimpleMethod"
failure_message = "failure message"

class AbortServicer(GRPCTestServerServicer):
# pylint:disable=C0103
async def SimpleMethod(self, request, context):
metadata = (
("meta", "data"),
)
await context.abort(
grpc.StatusCode.FAILED_PRECONDITION, failure_message, trailing_metadata=metadata
)

testcase = self

async def request(channel):
request = Request(client_id=1, request_data=failure_message)
msg = request.SerializeToString()

with testcase.assertRaises(grpc.RpcError) as cm:
await channel.unary_unary(rpc_call)(msg)

self.assertEqual(cm.exception.code(), grpc.StatusCode.FAILED_PRECONDITION)
self.assertEqual(cm.exception.details(), failure_message)

await run_with_test_server(request, servicer=AbortServicer())

spans_list = self.memory_exporter.get_finished_spans()
Expand Down
Loading