diff --git a/aws_xray_sdk/core/models/mimic_segment.py b/aws_xray_sdk/core/models/mimic_segment.py index 4c58ff69..4cbb0bc9 100644 --- a/aws_xray_sdk/core/models/mimic_segment.py +++ b/aws_xray_sdk/core/models/mimic_segment.py @@ -1,10 +1,14 @@ from .segment import Segment +from .facade_segment import FacadeSegment from ..exceptions.exceptions import MimicSegmentInvalidException class MimicSegment(Segment): """ - The MimicSegment is an entity that mimics a segment for the use of the serverless context. + The MimicSegment is an entity that mimics a segment. It's primary use is for a special-case + in Lambda; specifically, for the Serverless Design Pattern. It is not recommended to use + this MimicSegment for any other purpose. + When the MimicSegment is generated, its parent segment is assigned to be the FacadeSegment generated by the Lambda Environment. Upon serialization and transmission of the MimicSegment, it is converted to a locally-namespaced, subsegment. This is only done during serialization. @@ -16,8 +20,8 @@ class MimicSegment(Segment): is available to be used. """ - def __init__(self, facade_segment=None, original_segment=None): - if not original_segment or not facade_segment: + def __init__(self, facade_segment, original_segment): + if not issubclass(type(original_segment), Segment) or type(facade_segment) != FacadeSegment: raise MimicSegmentInvalidException("Invalid MimicSegment construction. " "Please put in the original segment and the facade segment.") super(MimicSegment, self).__init__(name=original_segment.name, entityid=original_segment.id, @@ -27,7 +31,7 @@ def __init__(self, facade_segment=None, original_segment=None): def __getstate__(self): """ Used during serialization. We mark the subsegment properties to let the dataplane know - that we want the mimic segment to be represented as a subsegment. + that we want the mimic segment to be transformed as a subsegment. """ properties = super(MimicSegment, self).__getstate__() properties['type'] = 'subsegment' diff --git a/aws_xray_sdk/core/serverless_context.py b/aws_xray_sdk/core/serverless_lambda_context.py similarity index 87% rename from aws_xray_sdk/core/serverless_context.py rename to aws_xray_sdk/core/serverless_lambda_context.py index 2fff9fe5..01f75e3f 100644 --- a/aws_xray_sdk/core/serverless_context.py +++ b/aws_xray_sdk/core/serverless_lambda_context.py @@ -6,13 +6,12 @@ from .models.mimic_segment import MimicSegment from .context import CXT_MISSING_STRATEGY_KEY from .lambda_launcher import LambdaContext -from .context import Context log = logging.getLogger(__name__) -class ServerlessContext(LambdaContext): +class ServerlessLambdaContext(LambdaContext): """ Context used specifically for running middlewares on Lambda through the Serverless design. This context is built on top of the LambdaContext, but @@ -23,7 +22,7 @@ class ServerlessContext(LambdaContext): ensures that FacadeSegments exist through underlying calls to _refresh_context(). """ def __init__(self, context_missing='RUNTIME_ERROR'): - super(ServerlessContext, self).__init__() + super(ServerlessLambdaContext, self).__init__() strategy = os.getenv(CXT_MISSING_STRATEGY_KEY, context_missing) self._context_missing = strategy @@ -38,7 +37,7 @@ def put_segment(self, segment): parent_facade_segment = self.__get_facade_entity() # type: FacadeSegment mimic_segment = MimicSegment(parent_facade_segment, segment) parent_facade_segment.add_subsegment(mimic_segment) - Context.put_segment(self, mimic_segment) + super(LambdaContext, self).put_segment(mimic_segment) def end_segment(self, end_time=None): """ @@ -46,7 +45,7 @@ def end_segment(self, end_time=None): """ # Close the last mimic segment opened then remove it from our facade segment. mimic_segment = self.get_trace_entity() - Context.end_segment(self, end_time) + super(LambdaContext, self).end_segment(end_time) if type(mimic_segment) == MimicSegment: # The facade segment can only hold mimic segments. facade_segment = self.__get_facade_entity() @@ -58,7 +57,7 @@ def put_subsegment(self, subsegment): another subsegment if they are the last opened entity. :param subsegment: The subsegment to to be added as a subsegment. """ - Context.put_subsegment(self, subsegment) + super(LambdaContext, self).put_subsegment(subsegment) def end_subsegment(self, end_time=None): """ @@ -69,7 +68,7 @@ def end_subsegment(self, end_time=None): system time will be used. :return: True on success, false if no parent mimic segment/subsegment is found. """ - return Context.end_subsegment(self, end_time) + return super(LambdaContext, self).end_subsegment(end_time) def __get_facade_entity(self): """ @@ -92,12 +91,12 @@ def get_trace_entity(self): # Call to Context.get_trace_entity() returns the latest mimic segment/subsegment if they exist. # Otherwise, returns None through the following way: # No mimic segment/subsegment exists so Context calls LambdaContext's handle_context_missing(). - # By default, Lambda's method returns no-op, so it will return None to ServerlessContext. + # By default, Lambda's method returns no-op, so it will return None to ServerlessLambdaContext. # Take that None as an indication to return the rightful handle_context_missing(), otherwise # return the entity. - entity = Context.get_trace_entity(self) + entity = super(LambdaContext, self).get_trace_entity() if entity is None: - return Context.handle_context_missing(self) + return super(LambdaContext, self).handle_context_missing() else: return entity @@ -116,11 +115,11 @@ def set_trace_entity(self, trace_entity): # behavior would be invoked. mimic_segment = trace_entity - Context.set_trace_entity(self, mimic_segment) + super(LambdaContext, self).set_trace_entity(mimic_segment) self.__get_facade_entity().subsegments = [mimic_segment] def _is_subsegment(self, entity): - return super(ServerlessContext, self)._is_subsegment(entity) and type(entity) != MimicSegment + return super(ServerlessLambdaContext, self)._is_subsegment(entity) and type(entity) != MimicSegment @property def context_missing(self): diff --git a/aws_xray_sdk/ext/django/middleware.py b/aws_xray_sdk/ext/django/middleware.py index 39ffffb0..07b732d9 100644 --- a/aws_xray_sdk/ext/django/middleware.py +++ b/aws_xray_sdk/ext/django/middleware.py @@ -1,9 +1,9 @@ import logging from aws_xray_sdk.core import xray_recorder -from aws_xray_sdk.core.lambda_launcher import check_in_lambda +from aws_xray_sdk.core.lambda_launcher import check_in_lambda, LambdaContext from aws_xray_sdk.core.models import http -from aws_xray_sdk.core.serverless_context import ServerlessContext +from aws_xray_sdk.core.serverless_lambda_context import ServerlessLambdaContext from aws_xray_sdk.core.utils import stacktrace from aws_xray_sdk.ext.util import calculate_sampling_decision, \ calculate_segment_name, construct_xray_header, prepare_response_header @@ -28,9 +28,11 @@ def __init__(self, get_response): self.get_response = get_response # The case when the middleware is initialized in a Lambda Context, we make sure - # to use the ServerlessContext so that the middleware properly functions. - if check_in_lambda() is not None: - xray_recorder.context = ServerlessContext() + # to use the ServerlessLambdaContext so that the middleware properly functions. + # We also check if the current context is a LambdaContext to not override customer + # provided contexts. + if check_in_lambda() is not None and type(xray_recorder.context) == LambdaContext: + xray_recorder.context = ServerlessLambdaContext() # hooks for django version >= 1.10 def __call__(self, request): diff --git a/aws_xray_sdk/ext/flask/middleware.py b/aws_xray_sdk/ext/flask/middleware.py index 65661076..74ec0d9e 100644 --- a/aws_xray_sdk/ext/flask/middleware.py +++ b/aws_xray_sdk/ext/flask/middleware.py @@ -1,9 +1,9 @@ import flask.templating from flask import request -from aws_xray_sdk.core.lambda_launcher import check_in_lambda +from aws_xray_sdk.core.lambda_launcher import check_in_lambda, LambdaContext from aws_xray_sdk.core.models import http -from aws_xray_sdk.core.serverless_context import ServerlessContext +from aws_xray_sdk.core.serverless_lambda_context import ServerlessLambdaContext from aws_xray_sdk.core.utils import stacktrace from aws_xray_sdk.ext.util import calculate_sampling_decision, \ calculate_segment_name, construct_xray_header, prepare_response_header @@ -21,9 +21,11 @@ def __init__(self, app, recorder): self.app.teardown_request(self._handle_exception) # The case when the middleware is initialized in a Lambda Context, we make sure - # to use the ServerlessContext so that the middleware properly functions. - if check_in_lambda() is not None: - self._recorder.context = ServerlessContext() + # to use the ServerlessLambdaContext so that the middleware properly functions. + # We also check if the current context is a LambdaContext to not override customer + # provided contexts. + if check_in_lambda() is not None and type(self._recorder.context) == LambdaContext: + self._recorder.context = ServerlessLambdaContext() _patch_render(recorder) diff --git a/tests/test_mimic_segment.py b/tests/test_mimic_segment.py index f3320b74..2c7311a7 100644 --- a/tests/test_mimic_segment.py +++ b/tests/test_mimic_segment.py @@ -30,9 +30,17 @@ def test_ready(): def test_invalid_init(): with pytest.raises(MimicSegmentInvalidException): MimicSegment(facade_segment=None, original_segment=original_segment) + with pytest.raises(MimicSegmentInvalidException): MimicSegment(facade_segment=facade_segment, original_segment=None) + with pytest.raises(MimicSegmentInvalidException): MimicSegment(facade_segment=Subsegment("Test", "local", original_segment), original_segment=None) + with pytest.raises(MimicSegmentInvalidException): MimicSegment(facade_segment=None, original_segment=Subsegment("Test", "local", original_segment)) + with pytest.raises(MimicSegmentInvalidException): + MimicSegment(facade_segment=facade_segment, original_segment=Subsegment("Test", "local", original_segment)) + with pytest.raises(MimicSegmentInvalidException): + MimicSegment(facade_segment=original_segment, original_segment=facade_segment) + MimicSegment(facade_segment=facade_segment, original_segment=original_segment) def test_init_similar(): diff --git a/tests/test_serverless_context.py b/tests/test_serverless_lambda_context.py similarity index 98% rename from tests/test_serverless_context.py rename to tests/test_serverless_lambda_context.py index e591cb25..54a65b3a 100644 --- a/tests/test_serverless_context.py +++ b/tests/test_serverless_lambda_context.py @@ -1,8 +1,7 @@ import os import pytest -from aws_xray_sdk.core import serverless_context -from aws_xray_sdk.core import context +from aws_xray_sdk.core.serverless_lambda_context import ServerlessLambdaContext from aws_xray_sdk.core.lambda_launcher import LAMBDA_TRACE_HEADER_KEY from aws_xray_sdk.core.exceptions.exceptions import AlreadyEndedException, SegmentNotFoundException from aws_xray_sdk.core.models.segment import Segment @@ -16,7 +15,7 @@ HEADER_VAR = "Root=%s;Parent=%s;Sampled=1" % (TRACE_ID, PARENT_ID) os.environ[LAMBDA_TRACE_HEADER_KEY] = HEADER_VAR -context = serverless_context.ServerlessContext() +context = ServerlessLambdaContext() service_name = "Test Flask Server"