Skip to content

Commit

Permalink
Send early spans at the beginning to cover for timeouts and failures
Browse files Browse the repository at this point in the history
  • Loading branch information
RafalSumislawski committed Jul 16, 2024
1 parent c88ac19 commit 97eb6cc
Showing 1 changed file with 93 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,13 @@ def custom_event_context_extractor(lambda_event):
Span,
SpanKind,
Link,
Tracer,
TracerProvider,
get_current_span,
get_tracer,
get_tracer_provider,
set_span_in_context
set_span_in_context,
use_span,
)
from opentelemetry.trace.propagation import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID
Expand Down Expand Up @@ -329,7 +331,7 @@ def _set_api_gateway_v2_proxy_attributes(
def _instrument(
wrapped_module_name,
wrapped_function_name,
flush_timeout,
flush_timeout: int,
event_context_extractor: Callable[[Any], Context],
tracer_provider: TracerProvider = None,
disable_aws_context_propagation: bool = False,
Expand Down Expand Up @@ -372,6 +374,8 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches

tracer = get_tracer(__name__, __version__, tracer_provider)

triggerSpan = None

apiGwSpan = None
try:
# If the request came from an API Gateway, extract http attributes from the event
Expand All @@ -394,6 +398,7 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches

apiGwSpan.set_attribute(SpanAttributes.FAAS_TRIGGER, "http")

triggerSpan = apiGwSpan
parent_context = set_span_in_context(apiGwSpan)
except Exception as ex:
pass
Expand All @@ -411,6 +416,7 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
s3TriggerSpan.set_attribute(SpanAttributes.FAAS_TRIGGER, "datasource")
s3TriggerSpan.set_attribute("faas.trigger.type", "S3")

triggerSpan = s3TriggerSpan
parent_context = set_span_in_context(s3TriggerSpan)

if lambda_event["Records"][0].get("s3"):
Expand Down Expand Up @@ -451,7 +457,8 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
sqsTriggerSpan.set_attribute(SpanAttributes.MESSAGING_DESTINATION, queue_url.split(":")[-1])
except IndexError:
pass


triggerSpan = sqsTriggerSpan
parent_context = set_span_in_context(sqsTriggerSpan)

if lambda_event["Records"][0].get("body"):
Expand Down Expand Up @@ -496,6 +503,7 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
except IndexError:
pass

triggerSpan = snsTriggerSpan
parent_context = set_span_in_context(snsTriggerSpan)

if lambda_event["Records"][0]["Sns"] and lambda_event["Records"][0]["Sns"].get("Message"):
Expand Down Expand Up @@ -542,7 +550,7 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
except IndexError:
pass


triggerSpan = kinesisTriggerSpan
parent_context = set_span_in_context(kinesisTriggerSpan)

if lambda_event["Records"][0]["kinesis"] and lambda_event["Records"][0]["kinesis"].get("data"):
Expand All @@ -568,6 +576,7 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
dynamoTriggerSpan.set_attribute(SpanAttributes.FAAS_TRIGGER, "datasource")
dynamoTriggerSpan.set_attribute("faas.trigger.type", "Dynamo DB")

triggerSpan = dynamoTriggerSpan
parent_context = set_span_in_context(dynamoTriggerSpan)

if lambda_event["Records"][0].get("dynamodb"):
Expand All @@ -589,6 +598,7 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
cognitoTriggerSpan.set_attribute(SpanAttributes.FAAS_TRIGGER, "datasource")
cognitoTriggerSpan.set_attribute("faas.trigger.type", "Cognito")

triggerSpan = cognitoTriggerSpan
parent_context = set_span_in_context(cognitoTriggerSpan)

if lambda_event["datasetRecords"]:
Expand Down Expand Up @@ -617,6 +627,8 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
eventBridgeTriggerSpan.set_attribute(SpanAttributes.FAAS_TRIGGER, "pubsub")
eventBridgeTriggerSpan.set_attribute("faas.trigger.type", "EventBridge")
eventBridgeTriggerSpan.set_attribute("aws.event.bridge.trigger.source", lambda_event.get("source"))

triggerSpan = eventBridgeTriggerSpan
parent_context = set_span_in_context(eventBridgeTriggerSpan)

eventBridgeTriggerSpan.set_attribute(
Expand All @@ -625,12 +637,23 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
)
except Exception as ex:
pass

if triggerSpan is not None:
triggerSpan.set_attribute("cx.internal.span.role", "trigger")

try:
with tracer.start_as_current_span(
invocationSpan = tracer.start_span(
name=orig_handler_name,
context=parent_context,
kind=span_kind,
)
invocationSpan.set_attribute("cx.internal.span.role", "invocation")

_sendEarlySpans(flush_timeout, tracer, tracer_provider, meter_provider, triggerSpan, invocationSpan)

with use_span(
span=invocationSpan,
end_on_exit=True,
) as span:
if span.is_recording():
lambda_context = args[1]
Expand Down Expand Up @@ -808,66 +831,15 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
)
except Exception:
pass
eventBridgeTriggerSpan.end()

now = time.time()
_tracer_provider = tracer_provider or get_tracer_provider()
if hasattr(_tracer_provider, "force_flush"):
try:
# NOTE: `force_flush` before function quit in case of Lambda freeze.
_tracer_provider.force_flush(flush_timeout)
except Exception: # pylint: disable=broad-except
logger.exception("TracerProvider failed to flush traces")
else:
logger.warning(
"TracerProvider was missing `force_flush` method. This is necessary in case of a Lambda freeze and would exist in the OTel SDK implementation."
)
eventBridgeTriggerSpan.end()

except Exception as e:
if apiGwSpan is not None:
apiGwSpan.end()
if s3TriggerSpan is not None:
s3TriggerSpan.end()
if sqsTriggerSpan is not None:
sqsTriggerSpan.end()
if snsTriggerSpan is not None:
snsTriggerSpan.end()
if dynamoTriggerSpan is not None:
dynamoTriggerSpan.end()
if cognitoTriggerSpan is not None:
cognitoTriggerSpan.end()
if eventBridgeTriggerSpan is not None:
eventBridgeTriggerSpan.end()
if kinesisTriggerSpan is not None:
kinesisTriggerSpan.end()

now = time.time()
_tracer_provider = tracer_provider or get_tracer_provider()
if hasattr(_tracer_provider, "force_flush"):
try:
# NOTE: `force_flush` before function quit in case of Lambda freeze.
_tracer_provider.force_flush(flush_timeout)
except Exception: # pylint: disable=broad-except
logger.warning(
"TracerProvider was missing `force_flush` method. This is necessary in case of a Lambda freeze and would exist in the OTel SDK implementation."
)
# pass

if triggerSpan is not None:
triggerSpan.end()
raise e

_meter_provider = meter_provider or get_meter_provider()
if hasattr(_meter_provider, "force_flush"):
rem = flush_timeout - (time.time() - now) * 1000
if rem > 0:
try:
# NOTE: `force_flush` before function quit in case of Lambda freeze.
_meter_provider.force_flush(rem)
except Exception: # pylint: disable=broad-except
logger.exception("MeterProvider failed to flush metrics")
else:
logger.warning(
"MeterProvider was missing `force_flush` method. This is necessary in case of a Lambda freeze and would exist in the OTel SDK implementation."
)
finally:
_flush(flush_timeout, tracer_provider, meter_provider)

return result

Expand All @@ -877,6 +849,66 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
_instrumented_lambda_handler_call,
)

def _sendEarlySpans(
    flush_timeout: int,
    tracer: Tracer,
    tracer_provider: TracerProvider,
    meter_provider: MeterProvider,
    trigger_span: Span,
    invocation_span: Span,
) -> None:
    """Emit and flush "early" copies of the trigger and invocation spans.

    The early copies are created, ended immediately and flushed before the
    wrapped handler runs, so that something is exported even if the
    invocation later times out or fails before the real spans are ended.

    This helper is best-effort: it is called from inside the handler's
    ``try`` block, whose ``except`` clause re-raises, so an error raised here
    would abort the invocation before the user handler is executed. Any
    failure while building the early spans is therefore logged and swallowed.

    Args:
        flush_timeout: Flush budget passed through to ``_flush``.
        tracer: Tracer used to create the early spans.
        tracer_provider: Provider to flush; forwarded to ``_flush``.
        meter_provider: Provider to flush; forwarded to ``_flush``.
        trigger_span: Span describing the event source, or ``None`` when no
            trigger was recognised.
        invocation_span: Span describing the handler invocation, or ``None``.
    """
    try:
        # Trigger first, then invocation -- same order as the real spans.
        for source_span in (trigger_span, invocation_span):
            if source_span is None:
                continue
            # Non-recording spans (e.g. dropped by the sampler) carry no
            # name/kind/attributes to copy and would never be exported.
            if not source_span.is_recording():
                continue
            _createEarlySpan(tracer, source_span).end()
    except Exception:  # pylint: disable=broad-except
        # Telemetry must never prevent the user's handler from running.
        logger.exception("Failed to create early spans")

    _flush(flush_timeout, tracer_provider, meter_provider)

def _createEarlySpan(
    tracer: Tracer,
    span: Span,
) -> Span:
    """Start an "early" stand-in span mirroring ``span``.

    The new span copies the name, kind and attributes of ``span`` and is
    tagged with ``cx.internal.span.state = "early"`` plus the trace id and
    span id of the span it stands in for, so the two can be correlated.

    The ids are stored as fixed-width lowercase hex strings (32 and 16
    digits) rather than raw integers: a trace id is wider than 64 bits and a
    span id may exceed the signed 64-bit range, so neither fits an integer
    attribute value on the wire.
    NOTE(review): confirm the backend reading ``cx.internal.trace.id`` and
    ``cx.internal.span.id`` expects this hex form.

    Args:
        tracer: Tracer used to start the early span.
        span: The real span to mirror; must expose ``name``, ``kind`` and
            ``attributes`` (i.e. be a recording span).

    Returns:
        The started early span. The caller is responsible for ending it.
    """
    early_span = tracer.start_span(name=span.name, kind=span.kind, attributes=span.attributes)
    original_context = span.get_span_context()
    early_span.set_attribute("cx.internal.span.state", "early")
    early_span.set_attribute("cx.internal.trace.id", format(original_context.trace_id, "032x"))
    early_span.set_attribute("cx.internal.span.id", format(original_context.span_id, "016x"))
    return early_span

def _flush(
    flush_timeout: int,
    tracer_provider: TracerProvider = None,
    meter_provider: MeterProvider = None,
) -> None:
    """Force-flush traces, then metrics, within a shared time budget.

    The tracer provider is flushed with the full ``flush_timeout``; the meter
    provider is flushed with whatever is left of it afterwards and is skipped
    when nothing is left. Flush failures are logged, never raised.

    Args:
        flush_timeout: Total budget in milliseconds.
        tracer_provider: Provider to flush; falls back to
            ``get_tracer_provider()`` when falsy.
        meter_provider: Provider to flush; falls back to
            ``get_meter_provider()`` when falsy.
    """
    started_at = time.time()

    # --- traces -----------------------------------------------------------
    traces = tracer_provider or get_tracer_provider()
    if not hasattr(traces, "force_flush"):
        logger.warning(
            "TracerProvider was missing `force_flush` method. This is necessary in case of a Lambda freeze and would exist in the OTel SDK implementation."
        )
    else:
        try:
            # NOTE: `force_flush` before function quit in case of Lambda freeze.
            traces.force_flush(flush_timeout)
        except Exception:  # pylint: disable=broad-except
            logger.exception("TracerProvider failed to flush traces")

    # --- metrics ----------------------------------------------------------
    metrics = meter_provider or get_meter_provider()
    if not hasattr(metrics, "force_flush"):
        logger.warning(
            "MeterProvider was missing `force_flush` method. This is necessary in case of a Lambda freeze and would exist in the OTel SDK implementation."
        )
        return

    # Elapsed wall time is in seconds; the budget is in milliseconds.
    elapsed_ms = (time.time() - started_at) * 1000
    remaining_ms = flush_timeout - elapsed_ms
    if remaining_ms > 0:
        try:
            # NOTE: `force_flush` before function quit in case of Lambda freeze.
            metrics.force_flush(remaining_ms)
        except Exception:  # pylint: disable=broad-except
            logger.exception("MeterProvider failed to flush metrics")


class AwsLambdaInstrumentor(BaseInstrumentor):
def instrumentation_dependencies(self) -> Collection[str]:
Expand Down

0 comments on commit 97eb6cc

Please sign in to comment.