Skip to content

Commit

Permalink
Support other propagators in Python Layer
Browse files Browse the repository at this point in the history
  • Loading branch information
NathanielRN committed Oct 5, 2021
1 parent 1280fb9 commit 38bf31c
Showing 1 changed file with 102 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,30 @@ def lambda_handler(event, context):
import logging
import os
from importlib import import_module
from typing import Collection
from wrapt import wrap_function_wrapper
from typing import Any, Collection

from opentelemetry.context.context import Context

# TODO: aws propagator
from opentelemetry.sdk.extension.aws.trace.propagation.aws_xray_format import (
AwsXRayFormat,
)
from opentelemetry.instrumentation.aws_lambda.package import _instruments
from opentelemetry.instrumentation.aws_lambda.version import __version__
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.utils import unwrap
from opentelemetry.propagate import get_global_textmap

from opentelemetry.sdk.extension.aws.trace.propagation.aws_xray_format import (
AwsXRayFormat,
TRACE_HEADER_KEY,
)
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import SpanKind, get_tracer, get_tracer_provider
from opentelemetry.trace import (
SpanKind,
Tracer,
get_tracer,
get_tracer_provider,
)
from opentelemetry.trace.propagation import get_current_span

from wrapt import wrap_function_wrapper

logger = logging.getLogger(__name__)

Expand All @@ -69,15 +80,22 @@ def instrumentation_dependencies(self) -> Collection[str]:
return _instruments

def _instrument(self, **kwargs):
"""Instruments Lambda Handlers on AWS Lambda
"""Instruments Lambda Handlers on AWS Lambda.
Read more about how instrumentation is decided:
https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/trace/semantic_conventions/instrumentation/aws-lambda.md#instrumenting-aws-lambda
Args:
**kwargs: Optional arguments
``tracer_provider``: a TracerProvider, defaults to global
``event_context_extractor``: a method which takes the Lambda
Event as input and provides the object which contains the
context as output (usually HTTP headers)
"""
tracer = get_tracer(
__name__, __version__, kwargs.get("tracer_provider")
)
event_context_extractor = kwargs.get("event_context_extractor")

lambda_handler = os.environ.get(
"ORIG_HANDLER", os.environ.get("_HANDLER")
Expand All @@ -87,7 +105,10 @@ def _instrument(self, **kwargs):
self._wrapped_function_name = wrapped_names[1]

_instrument(
tracer, self._wrapped_module_name, self._wrapped_function_name
tracer,
self._wrapped_module_name,
self._wrapped_function_name,
event_context_extractor,
)

def _uninstrument(self, **kwargs):
Expand All @@ -97,16 +118,83 @@ def _uninstrument(self, **kwargs):
)


def _instrument(tracer, wrapped_module_name, wrapped_function_name):
def _default_event_context_extractor(lambda_event: Any) -> Context:
"""Default way of extracting the context from the Lambda Event.
Assumes the Lambda Event is a map with the headers under the 'headers' key.
This is the mapping to use when the Lambda is invoked by an API Gateway
REST API where API Gateway is acting as a pure proxy for the request.
Args:
lambda_event: user-defined, so it could be anything, but this
method counts it being a map with a 'headers' key
Returns:
A Context with configuration found in the carrier.
"""
try:
headers = lambda_event["headers"]
except (TypeError, KeyError):
logger.warning("Failed to extract context from Lambda Event.")
headers = {}
return get_global_textmap().extract(headers)


def _instrument(
tracer: Tracer,
wrapped_module_name,
wrapped_function_name,
event_context_extractor=None,
):
def _determine_parent_context(lambda_event: Any) -> Context:
"""Determine the parent context for the current Lambda invocation.
Refer:
https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/trace/semantic_conventions/instrumentation/aws-lambda.md#determining-the-parent-of-a-span
Args:
lambda_event: user-defined, so it could be anything, but this
method counts it being a map with a 'headers' key
Returns:
A Context with configuration found in the carrier.
"""
parent_context = None

xray_env_var = os.environ.get("_X_AMZN_TRACE_ID")

if xray_env_var:
parent_context = AwsXRayFormat().extract(
{TRACE_HEADER_KEY: xray_env_var}
)

if (
parent_context
and get_current_span(parent_context)
.get_span_context()
.trace_flags.sampled
):
return parent_context

logger.debug(
"X-Ray propagation failed, use user-configured propagators to extract context from Lambda Event."
)

if event_context_extractor:
parent_context = event_context_extractor(lambda_event)
else:
parent_context = _default_event_context_extractor(lambda_event)

return parent_context

def _instrumented_lambda_handler_call(call_wrapped, instance, args, kwargs):
orig_handler_name = ".".join(
[wrapped_module_name, wrapped_function_name]
)

# TODO: enable propagate from AWS by env variable
xray_trace_id = os.environ.get("_X_AMZN_TRACE_ID", "")
propagator = AwsXRayFormat()
parent_context = propagator.extract({"X-Amzn-Trace-Id": xray_trace_id})
lambda_event = args[0]

parent_context = _determine_parent_context(lambda_event)

with tracer.start_as_current_span(
name=orig_handler_name, context=parent_context, kind=SpanKind.SERVER
Expand Down

0 comments on commit 38bf31c

Please sign in to comment.