Skip to content

Commit

Permalink
ext: Expect tracer provider instead of tracer in integrations (#602)
Browse files Browse the repository at this point in the history
Standardize the interface that trace providers are specified in integrations, as specified in #585.

Adding a helper to create and return a configured TracerProvider with a the span
processor and the memory exporter

api: Add tracer provider parameter to trace.get_tracer(). This eliminates the need for a helper function and boilerplate code to retrieve the appropriate tracer from a passed tracer_provider.
  • Loading branch information
mauriciovasquezbernal committed Apr 23, 2020
1 parent 232bfdd commit 7cb57c7
Show file tree
Hide file tree
Showing 19 changed files with 163 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@
trace.get_tracer_provider().add_span_processor(
SimpleExportSpanProcessor(ConsoleSpanExporter())
)
tracer = trace.get_tracer(__name__)


def run():
Expand All @@ -73,7 +72,7 @@ def run():
# of the code.
with grpc.insecure_channel("localhost:50051") as channel:

channel = intercept_channel(channel, client_interceptor(tracer))
channel = intercept_channel(channel, client_interceptor())

stub = helloworld_pb2_grpc.GreeterStub(channel)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@
trace.get_tracer_provider().add_span_processor(
SimpleExportSpanProcessor(ConsoleSpanExporter())
)
tracer = trace.get_tracer(__name__)


class Greeter(helloworld_pb2_grpc.GreeterServicer):
Expand All @@ -75,7 +74,7 @@ def SayHello(self, request, context):
def serve():

server = grpc.server(futures.ThreadPoolExecutor())
server = intercept_server(server, server_interceptor(tracer))
server = intercept_server(server, server_interceptor())

helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), server)
server.add_insecure_port("[::]:50051")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
trace.get_tracer_provider().add_span_processor(
SimpleExportSpanProcessor(ConsoleSpanExporter())
)
tracer = trace.get_tracer(__name__)


def make_route_note(message, latitude, longitude):
Expand Down Expand Up @@ -154,7 +153,7 @@ def run():
# used in circumstances in which the with statement does not fit the needs
# of the code.
with grpc.insecure_channel("localhost:50051") as channel:
channel = intercept_channel(channel, client_interceptor(tracer))
channel = intercept_channel(channel, client_interceptor())

stub = route_guide_pb2_grpc.RouteGuideStub(channel)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@
trace.get_tracer_provider().add_span_processor(
SimpleExportSpanProcessor(ConsoleSpanExporter())
)
tracer = trace.get_tracer(__name__)


def get_feature(feature_db, point):
Expand Down Expand Up @@ -164,7 +163,7 @@ def RouteChat(self, request_iterator, context):

def serve():
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
server = intercept_server(server, server_interceptor(tracer))
server = intercept_server(server, server_interceptor())

route_guide_pb2_grpc.add_RouteGuideServicer_to_server(
RouteGuideServicer(), server
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,16 @@
import mysql.connector
import pyodbc
from opentelemetry import trace
from opentelemetry.ext.dbapi import trace_integration
from opentelemetry.trace import TracerProvider
trace.set_tracer_provider(TracerProvider())
tracer = trace.get_tracer(__name__)
# Ex: mysql.connector
trace_integration(tracer, mysql.connector, "connect", "mysql", "sql")
trace_integration(mysql.connector, "connect", "mysql", "sql")
# Ex: pyodbc
trace_integration(tracer, pyodbc, "Connection", "odbc", "sql")
trace_integration(pyodbc, "Connection", "odbc", "sql")
API
---
Expand All @@ -44,13 +45,44 @@

import wrapt

from opentelemetry.trace import SpanKind, Tracer
from opentelemetry.ext.dbapi.version import __version__
from opentelemetry.trace import SpanKind, Tracer, TracerProvider, get_tracer
from opentelemetry.trace.status import Status, StatusCanonicalCode

logger = logging.getLogger(__name__)


def trace_integration(
connect_module: typing.Callable[..., any],
connect_method_name: str,
database_component: str,
database_type: str = "",
connection_attributes: typing.Dict = None,
tracer_provider: typing.Optional[TracerProvider] = None,
):
"""Integrate with DB API library.
https://www.python.org/dev/peps/pep-0249/
Args:
connect_module: Module name where connect method is available.
connect_method_name: The connect method name.
database_component: Database driver name or database name "JDBI", "jdbc", "odbc", "postgreSQL".
database_type: The Database type. For any SQL database, "sql".
connection_attributes: Attribute names for database, port, host and user in Connection object.
tracer_provider: The :class:`TracerProvider` to use. If ommited the current configured one is used.
"""
tracer = get_tracer(__name__, __version__, tracer_provider)
wrap_connect(
tracer,
connect_module,
connect_method_name,
database_component,
database_type,
connection_attributes,
)


def wrap_connect(
tracer: Tracer,
connect_module: typing.Callable[..., any],
connect_method_name: str,
Expand All @@ -71,7 +103,7 @@ def trace_integration(
"""

# pylint: disable=unused-argument
def wrap_connect(
def wrap_connect_(
wrapped: typing.Callable[..., any],
instance: typing.Any,
args: typing.Tuple[any, any],
Expand All @@ -87,7 +119,7 @@ def wrap_connect(

try:
wrapt.wrap_function_wrapper(
connect_module, connect_method_name, wrap_connect
connect_module, connect_method_name, wrap_connect_
)
except Exception as ex: # pylint: disable=broad-except
logger.warning("Failed to integrate with DB API. %s", str(ex))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def setUpClass(cls):
cls._connection = None
cls._cursor = None
cls._tracer = cls.tracer_provider.get_tracer(__name__)
trace_integration(cls._tracer)
trace_integration(cls.tracer_provider)
cls._connection = mysql.connector.connect(
user=MYSQL_USER,
password=MYSQL_PASSWORD,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def setUpClass(cls):
cls._connection = None
cls._cursor = None
cls._tracer = cls.tracer_provider.get_tracer(__name__)
trace_integration(cls._tracer)
trace_integration(cls.tracer_provider)
cls._connection = psycopg2.connect(
dbname=POSTGRES_DB_NAME,
user=POSTGRES_USER,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class TestFunctionalPymongo(TestBase):
def setUpClass(cls):
super().setUpClass()
cls._tracer = cls.tracer_provider.get_tracer(__name__)
trace_integration(cls._tracer)
trace_integration(cls.tracer_provider)
client = MongoClient(
MONGODB_HOST, MONGODB_PORT, serverSelectionTimeoutMS=2000
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@
# pylint:disable=no-name-in-module
# pylint:disable=relative-beyond-top-level

from opentelemetry import trace
from opentelemetry.ext.grpc.version import __version__

def client_interceptor(tracer):

def client_interceptor(tracer_provider=None):
"""Create a gRPC client channel interceptor.
Args:
Expand All @@ -29,10 +32,12 @@ def client_interceptor(tracer):
"""
from . import _client

tracer = trace.get_tracer(__name__, __version__, tracer_provider)

return _client.OpenTelemetryClientInterceptor(tracer)


def server_interceptor(tracer):
def server_interceptor(tracer_provider=None):
"""Create a gRPC server interceptor.
Args:
Expand All @@ -43,4 +48,6 @@ def server_interceptor(tracer):
"""
from . import _server

tracer = trace.get_tracer(__name__, __version__, tracer_provider)

return _server.OpenTelemetryServerInterceptor(tracer)
26 changes: 12 additions & 14 deletions ext/opentelemetry-ext-grpc/tests/test_server_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import grpc

import opentelemetry.ext.grpc
from opentelemetry import trace
from opentelemetry.ext.grpc import server_interceptor
from opentelemetry.ext.grpc.grpcext import intercept_server
Expand Down Expand Up @@ -48,15 +49,11 @@ def service(self, handler_call_details):


class TestOpenTelemetryServerInterceptor(TestBase):
def setUp(self):
super().setUp()
self.tracer = self.tracer_provider.get_tracer(__name__)

def test_create_span(self):
"""Check that the interceptor wraps calls with spans server-side."""

# Intercept gRPC calls...
interceptor = server_interceptor(self.tracer)
interceptor = server_interceptor()

# No-op RPC handler
def handler(request, context):
Expand Down Expand Up @@ -87,18 +84,21 @@ def handler(request, context):
self.assertEqual(span.name, "")
self.assertIs(span.kind, trace.SpanKind.SERVER)

# Check version and name in span's instrumentation info
self.check_span_instrumentation_info(span, opentelemetry.ext.grpc)

def test_span_lifetime(self):
"""Check that the span is active for the duration of the call."""

tracer_provider = trace_sdk.TracerProvider()
tracer = tracer_provider.get_tracer(__name__)
interceptor = server_interceptor(tracer)
interceptor = server_interceptor()
tracer = self.tracer_provider.get_tracer(__name__)

# To capture the current span at the time the handler is called
active_span_in_handler = None

def handler(request, context):
nonlocal active_span_in_handler
# The current span is shared among all the tracers.
active_span_in_handler = tracer.get_current_span()
return b""

Expand Down Expand Up @@ -128,10 +128,9 @@ def handler(request, context):
def test_sequential_server_spans(self):
"""Check that sequential RPCs get separate server spans."""

tracer_provider = trace_sdk.TracerProvider()
tracer = tracer_provider.get_tracer(__name__)
tracer = self.tracer_provider.get_tracer(__name__)

interceptor = server_interceptor(tracer)
interceptor = server_interceptor()

# Capture the currently active span in each thread
active_spans_in_handler = []
Expand Down Expand Up @@ -176,10 +175,9 @@ def test_concurrent_server_spans(self):
context.
"""

tracer_provider = trace_sdk.TracerProvider()
tracer = tracer_provider.get_tracer(__name__)
tracer = self.tracer_provider.get_tracer(__name__)

interceptor = server_interceptor(tracer)
interceptor = server_interceptor()

# Capture the currently active span in each thread
active_spans_in_handler = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,8 @@
from opentelemetry.ext.mysql import trace_integration
trace.set_tracer_provider(TracerProvider())
tracer = trace.get_tracer(__name__)
trace_integration(tracer)
trace_integration()
cnx = mysql.connector.connect(database='MySQL_Database')
cursor = cnx.cursor()
cursor.execute("INSERT INTO test (testField) VALUES (123)"
Expand All @@ -42,23 +41,29 @@
---
"""

import typing

import mysql.connector

from opentelemetry.ext.dbapi import trace_integration as db_integration
from opentelemetry.trace import Tracer
from opentelemetry.ext.dbapi import wrap_connect
from opentelemetry.ext.mysql.version import __version__
from opentelemetry.trace import TracerProvider, get_tracer


def trace_integration(tracer: Tracer):
def trace_integration(tracer_provider: typing.Optional[TracerProvider] = None):
"""Integrate with MySQL Connector/Python library.
https://dev.mysql.com/doc/connector-python/en/
"""

tracer = get_tracer(__name__, __version__, tracer_provider)

connection_attributes = {
"database": "database",
"port": "server_port",
"host": "server_host",
"user": "user",
}
db_integration(
wrap_connect(
tracer,
mysql.connector,
"connect",
Expand Down
33 changes: 28 additions & 5 deletions ext/opentelemetry-ext-mysql/tests/test_mysql_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,43 @@

import mysql.connector

from opentelemetry.ext.mysql import trace_integration
import opentelemetry.ext.mysql
from opentelemetry.sdk import resources
from opentelemetry.test.test_base import TestBase


class TestMysqlIntegration(TestBase):
def test_trace_integration(self):
tracer = self.tracer_provider.get_tracer(__name__)
with mock.patch("mysql.connector.connect") as mock_connect:
mock_connect.get.side_effect = mysql.connector.MySQLConnection()
opentelemetry.ext.mysql.trace_integration()
cnx = mysql.connector.connect(database="test")
cursor = cnx.cursor()
query = "SELECT * FROM test"
cursor.execute(query)

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]

# Check version and name in span's instrumentation info
self.check_span_instrumentation_info(span, opentelemetry.ext.mysql)

def test_custom_tracer_provider(self):
resource = resources.Resource.create({})
result = self.create_tracer_provider(resource=resource)
tracer_provider, exporter = result

with mock.patch("mysql.connector.connect") as mock_connect:
mock_connect.get.side_effect = mysql.connector.MySQLConnection()
trace_integration(tracer)
opentelemetry.ext.mysql.trace_integration(tracer_provider)
cnx = mysql.connector.connect(database="test")
cursor = cnx.cursor()
query = "SELECT * FROM test"
cursor.execute(query)
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

span_list = exporter.get_finished_spans()
self.assertEqual(len(span_list), 1)
span = span_list[0]

self.assertIs(span.resource, resource)
Loading

0 comments on commit 7cb57c7

Please sign in to comment.