Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ext: Expect tracer provider instead of tracer in integrations #602

Merged
merged 14 commits into from
Apr 23, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@
trace.get_tracer_provider().add_span_processor(
SimpleExportSpanProcessor(ConsoleSpanExporter())
)
tracer = trace.get_tracer(__name__)


def run():
Expand All @@ -73,7 +72,7 @@ def run():
# of the code.
with grpc.insecure_channel("localhost:50051") as channel:

channel = intercept_channel(channel, client_interceptor(tracer))
channel = intercept_channel(channel, client_interceptor())

stub = helloworld_pb2_grpc.GreeterStub(channel)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@
trace.get_tracer_provider().add_span_processor(
SimpleExportSpanProcessor(ConsoleSpanExporter())
)
tracer = trace.get_tracer(__name__)


class Greeter(helloworld_pb2_grpc.GreeterServicer):
Expand All @@ -75,7 +74,7 @@ def SayHello(self, request, context):
def serve():

server = grpc.server(futures.ThreadPoolExecutor())
server = intercept_server(server, server_interceptor(tracer))
server = intercept_server(server, server_interceptor())

helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), server)
server.add_insecure_port("[::]:50051")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
trace.get_tracer_provider().add_span_processor(
SimpleExportSpanProcessor(ConsoleSpanExporter())
)
tracer = trace.get_tracer(__name__)


def make_route_note(message, latitude, longitude):
Expand Down Expand Up @@ -154,7 +153,7 @@ def run():
# used in circumstances in which the with statement does not fit the needs
# of the code.
with grpc.insecure_channel("localhost:50051") as channel:
channel = intercept_channel(channel, client_interceptor(tracer))
channel = intercept_channel(channel, client_interceptor())

stub = route_guide_pb2_grpc.RouteGuideStub(channel)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@
trace.get_tracer_provider().add_span_processor(
SimpleExportSpanProcessor(ConsoleSpanExporter())
)
tracer = trace.get_tracer(__name__)


def get_feature(feature_db, point):
Expand Down Expand Up @@ -164,7 +163,7 @@ def RouteChat(self, request_iterator, context):

def serve():
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
server = intercept_server(server, server_interceptor(tracer))
server = intercept_server(server, server_interceptor())

route_guide_pb2_grpc.add_RouteGuideServicer_to_server(
RouteGuideServicer(), server
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@
from opentelemetry.trace import TracerProvider

trace.set_tracer_provider(TracerProvider())
tracer = trace.get_tracer(__name__)

# Ex: mysql.connector
trace_integration(tracer, mysql.connector, "connect", "mysql", "sql")
trace_integration(mysql.connector, "connect", "mysql", "sql")
# Ex: pyodbc
trace_integration(tracer, pyodbc, "Connection", "odbc", "sql")
trace_integration(pyodbc, "Connection", "odbc", "sql")

API
---
Expand All @@ -44,13 +44,52 @@

import wrapt

from opentelemetry.trace import SpanKind, Tracer
from opentelemetry.ext.dbapi.version import __version__
from opentelemetry.trace import (
SpanKind,
Tracer,
TracerProvider,
get_tracer_provider,
)
from opentelemetry.trace.status import Status, StatusCanonicalCode

logger = logging.getLogger(__name__)


def trace_integration(
connect_module: typing.Callable[..., any],
connect_method_name: str,
database_component: str,
database_type: str = "",
connection_attributes: typing.Dict = None,
tracer_provider: typing.Optional[TracerProvider] = None,
):
"""Integrate with DB API library.
https://www.python.org/dev/peps/pep-0249/

Args:
connect_module: Module name where connect method is available.
connect_method_name: The connect method name.
database_component: Database driver name or database name "JDBI", "jdbc", "odbc", "postgreSQL".
database_type: The Database type. For any SQL database, "sql".
connection_attributes: Attribute names for database, port, host and user in Connection object.
tracer_provider: The :class:`TracerProvider` to use. If ommited the current configured one is used.
"""
if tracer_provider is None:
tracer_provider = get_tracer_provider()

tracer = tracer_provider.get_tracer(__name__, __version__)
wrap_connect(
tracer,
connect_module,
connect_method_name,
database_component,
database_type,
connection_attributes,
)


def wrap_connect(
tracer: Tracer,
connect_module: typing.Callable[..., any],
connect_method_name: str,
Expand All @@ -71,7 +110,7 @@ def trace_integration(
"""

# pylint: disable=unused-argument
def wrap_connect(
def wrap_connect_(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe a slight style thing, but it's a little harder for me to spot a trailing underscore rather than a leading underscore.

wrapped: typing.Callable[..., any],
instance: typing.Any,
args: typing.Tuple[any, any],
Expand All @@ -87,7 +126,7 @@ def wrap_connect(

try:
wrapt.wrap_function_wrapper(
connect_module, connect_method_name, wrap_connect
connect_module, connect_method_name, wrap_connect_
)
except Exception as ex: # pylint: disable=broad-except
logger.warning("Failed to integrate with DB API. %s", str(ex))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def setUpClass(cls):
cls._connection = None
cls._cursor = None
cls._tracer = cls.tracer_provider.get_tracer(__name__)
trace_integration(cls._tracer)
trace_integration(cls.tracer_provider)
cls._connection = mysql.connector.connect(
user=MYSQL_USER,
password=MYSQL_PASSWORD,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def setUpClass(cls):
cls._connection = None
cls._cursor = None
cls._tracer = cls.tracer_provider.get_tracer(__name__)
trace_integration(cls._tracer)
trace_integration(cls.tracer_provider)
cls._connection = psycopg2.connect(
dbname=POSTGRES_DB_NAME,
user=POSTGRES_USER,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class TestFunctionalPymongo(TestBase):
def setUpClass(cls):
super().setUpClass()
cls._tracer = cls.tracer_provider.get_tracer(__name__)
trace_integration(cls._tracer)
trace_integration(cls.tracer_provider)
client = MongoClient(
MONGODB_HOST, MONGODB_PORT, serverSelectionTimeoutMS=2000
)
Expand Down
17 changes: 15 additions & 2 deletions ext/opentelemetry-ext-grpc/src/opentelemetry/ext/grpc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@
# pylint:disable=no-name-in-module
# pylint:disable=relative-beyond-top-level

from opentelemetry import trace
from opentelemetry.ext.grpc.version import __version__

def client_interceptor(tracer):

def client_interceptor(tracer_provider=None):
"""Create a gRPC client channel interceptor.

Args:
Expand All @@ -29,10 +32,15 @@ def client_interceptor(tracer):
"""
from . import _client

if tracer_provider is None:
tracer_provider = trace.get_tracer_provider()

tracer = tracer_provider.get_tracer(__name__, __version__)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this pattern will occur more often, would it make sense to add an utility function setup_tracer(tracer_provider: Optional[TracerProvider], name, version) to the API to encapsulate these three lines?

Copy link
Member Author

@mauriciovasquezbernal mauriciovasquezbernal Apr 21, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking if it could make sense to add an optional TracerProvider parameter to get_tracer(). It'll be the same setup_tracer you're proposing.


return _client.OpenTelemetryClientInterceptor(tracer)


def server_interceptor(tracer):
def server_interceptor(tracer_provider=None):
"""Create a gRPC server interceptor.

Args:
Expand All @@ -43,4 +51,9 @@ def server_interceptor(tracer):
"""
from . import _server

if tracer_provider is None:
tracer_provider = trace.get_tracer_provider()

tracer = tracer_provider.get_tracer(__name__, __version__)

return _server.OpenTelemetryServerInterceptor(tracer)
26 changes: 12 additions & 14 deletions ext/opentelemetry-ext-grpc/tests/test_server_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import grpc

import opentelemetry.ext.grpc
from opentelemetry import trace
from opentelemetry.ext.grpc import server_interceptor
from opentelemetry.ext.grpc.grpcext import intercept_server
Expand Down Expand Up @@ -48,15 +49,11 @@ def service(self, handler_call_details):


class TestOpenTelemetryServerInterceptor(TestBase):
def setUp(self):
super().setUp()
self.tracer = self.tracer_provider.get_tracer(__name__)

def test_create_span(self):
"""Check that the interceptor wraps calls with spans server-side."""

# Intercept gRPC calls...
interceptor = server_interceptor(self.tracer)
interceptor = server_interceptor()

# No-op RPC handler
def handler(request, context):
Expand Down Expand Up @@ -87,18 +84,21 @@ def handler(request, context):
self.assertEqual(span.name, "")
self.assertIs(span.kind, trace.SpanKind.SERVER)

# Check version and name in span's instrumentation info
self.check_span_instrumentation_info(span, opentelemetry.ext.grpc)

def test_span_lifetime(self):
"""Check that the span is active for the duration of the call."""

tracer_provider = trace_sdk.TracerProvider()
tracer = tracer_provider.get_tracer(__name__)
interceptor = server_interceptor(tracer)
interceptor = server_interceptor()
tracer = self.tracer_provider.get_tracer(__name__)

# To capture the current span at the time the handler is called
active_span_in_handler = None

def handler(request, context):
nonlocal active_span_in_handler
# The current span is shared among all the tracers.
active_span_in_handler = tracer.get_current_span()
return b""

Expand Down Expand Up @@ -128,10 +128,9 @@ def handler(request, context):
def test_sequential_server_spans(self):
"""Check that sequential RPCs get separate server spans."""

tracer_provider = trace_sdk.TracerProvider()
tracer = tracer_provider.get_tracer(__name__)
tracer = self.tracer_provider.get_tracer(__name__)

interceptor = server_interceptor(tracer)
interceptor = server_interceptor()

# Capture the currently active span in each thread
active_spans_in_handler = []
Expand Down Expand Up @@ -176,10 +175,9 @@ def test_concurrent_server_spans(self):
context.
"""

tracer_provider = trace_sdk.TracerProvider()
tracer = tracer_provider.get_tracer(__name__)
tracer = self.tracer_provider.get_tracer(__name__)

interceptor = server_interceptor(tracer)
interceptor = server_interceptor()

# Capture the currently active span in each thread
active_spans_in_handler = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
trace.set_tracer_provider(TracerProvider())
tracer = trace.get_tracer(__name__)

trace_integration(tracer)
trace_integration()
cnx = mysql.connector.connect(database='MySQL_Database')
cursor = cnx.cursor()
cursor.execute("INSERT INTO test (testField) VALUES (123)"
Expand All @@ -41,23 +41,32 @@
---
"""

import typing

import mysql.connector

from opentelemetry.ext.dbapi import trace_integration as db_integration
from opentelemetry.trace import Tracer
from opentelemetry.ext.dbapi import wrap_connect
from opentelemetry.ext.mysql.version import __version__
from opentelemetry.trace import TracerProvider, get_tracer_provider


def trace_integration(tracer: Tracer):
def trace_integration(tracer_provider: typing.Optional[TracerProvider] = None):
"""Integrate with MySQL Connector/Python library.
https://dev.mysql.com/doc/connector-python/en/
"""

if tracer_provider is None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any reason not to call the original db_integration with the tracer_provider as a parameter since it's doing the same check?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is that it'll create the tracer with the version and name from the db-api module. I had to split the logic from opentelemetry.ext.dbapi.trace_integration into two different functions, one receiving a tracer provider (to be used by the user), and other one receiving a tracer (to be used by other integrations).

tracer_provider = get_tracer_provider()

tracer = tracer_provider.get_tracer(__name__, __version__)

connection_attributes = {
"database": "database",
"port": "server_port",
"host": "server_host",
"user": "user",
}
db_integration(
wrap_connect(
tracer,
mysql.connector,
"connect",
Expand Down
34 changes: 29 additions & 5 deletions ext/opentelemetry-ext-mysql/tests/test_mysql_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,44 @@

import mysql.connector

from opentelemetry.ext.mysql import trace_integration
import opentelemetry.ext.mysql
from opentelemetry.sdk import resources
from opentelemetry.test.test_base import TestBase


class TestMysqlIntegration(TestBase):
def test_trace_integration(self):
tracer = self.tracer_provider.get_tracer(__name__)
with mock.patch("mysql.connector.connect") as mock_connect:
mock_connect.get.side_effect = mysql.connector.MySQLConnection()
opentelemetry.ext.mysql.trace_integration()
cnx = mysql.connector.connect(database="test")
cursor = cnx.cursor()
query = "SELECT * FROM test"
cursor.execute(query)

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]
# TODO: Add more tests?
mauriciovasquezbernal marked this conversation as resolved.
Show resolved Hide resolved

# Check version and name in span's instrumentation info
self.check_span_instrumentation_info(span, opentelemetry.ext.mysql)

def test_custom_tracer_provider(self):
resource = resources.Resource.create({})
result = self.create_tracer_provider(resource=resource)
tracer_provider, exporter = result

with mock.patch("mysql.connector.connect") as mock_connect:
mock_connect.get.side_effect = mysql.connector.MySQLConnection()
trace_integration(tracer)
opentelemetry.ext.mysql.trace_integration(tracer_provider)
cnx = mysql.connector.connect(database="test")
cursor = cnx.cursor()
query = "SELECT * FROM test"
cursor.execute(query)
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

span_list = exporter.get_finished_spans()
self.assertEqual(len(span_list), 1)
span = span_list[0]

self.assertIs(span.resource, resource)
Loading