Skip to content

Commit

Permalink
Flask & Django create subsegments only if in LambdaContext & Lambda E…
Browse files Browse the repository at this point in the history
…nvironment (aws#138)

* Added unit tests to ensure these frameworks generate segmentby default and
subsegments if in lambda context and environment.
  • Loading branch information
chanchiem committed Feb 27, 2019
1 parent 741bb6c commit 512312d
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 18 deletions.
12 changes: 6 additions & 6 deletions aws_xray_sdk/ext/django/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from aws_xray_sdk.core.utils import stacktrace
from aws_xray_sdk.ext.util import calculate_sampling_decision, \
calculate_segment_name, construct_xray_header, prepare_response_header
from aws_xray_sdk.core.lambda_launcher import check_in_lambda
from aws_xray_sdk.core.lambda_launcher import check_in_lambda, LambdaContext


log = logging.getLogger(__name__)
Expand All @@ -25,10 +25,10 @@ class XRayMiddleware(object):
def __init__(self, get_response):

self.get_response = get_response
self.in_lambda = False
self.in_lambda_ctx = False

if check_in_lambda():
self.in_lambda = True
if check_in_lambda() and type(xray_recorder.context) == LambdaContext:
self.in_lambda_ctx = True

# hooks for django version >= 1.10
def __call__(self, request):
Expand All @@ -51,7 +51,7 @@ def __call__(self, request):
sampling_req=sampling_req,
)

if self.in_lambda:
if self.in_lambda_ctx:
segment = xray_recorder.begin_subsegment(name)
else:
segment = xray_recorder.begin_segment(
Expand Down Expand Up @@ -83,7 +83,7 @@ def __call__(self, request):
segment.put_http_meta(http.CONTENT_LENGTH, length)
response[http.XRAY_HEADER] = prepare_response_header(xray_header, segment)

if self.in_lambda:
if self.in_lambda_ctx:
xray_recorder.end_subsegment()
else:
xray_recorder.end_segment()
Expand Down
18 changes: 9 additions & 9 deletions aws_xray_sdk/ext/flask/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from aws_xray_sdk.core.utils import stacktrace
from aws_xray_sdk.ext.util import calculate_sampling_decision, \
calculate_segment_name, construct_xray_header, prepare_response_header
from aws_xray_sdk.core.lambda_launcher import check_in_lambda
from aws_xray_sdk.core.lambda_launcher import check_in_lambda, LambdaContext


class XRayMiddleware(object):
Expand All @@ -18,10 +18,10 @@ def __init__(self, app, recorder):
self.app.before_request(self._before_request)
self.app.after_request(self._after_request)
self.app.teardown_request(self._handle_exception)
self.in_lambda = False
self.in_lambda_ctx = False

if check_in_lambda():
self.in_lambda = True
if check_in_lambda() and type(self._recorder.context) == LambdaContext:
self.in_lambda_ctx = True

_patch_render(recorder)

Expand All @@ -44,7 +44,7 @@ def _before_request(self):
sampling_req=sampling_req,
)

if self.in_lambda:
if self.in_lambda_ctx:
segment = self._recorder.begin_subsegment(name)
else:
segment = self._recorder.begin_segment(
Expand All @@ -67,7 +67,7 @@ def _before_request(self):
segment.put_http_meta(http.CLIENT_IP, req.remote_addr)

def _after_request(self, response):
if self.in_lambda:
if self.in_lambda_ctx:
segment = self._recorder.current_subsegment()
else:
segment = self._recorder.current_segment()
Expand All @@ -81,7 +81,7 @@ def _after_request(self, response):
if cont_len:
segment.put_http_meta(http.CONTENT_LENGTH, int(cont_len))

if self.in_lambda:
if self.in_lambda_ctx:
self._recorder.end_subsegment()
else:
self._recorder.end_segment()
Expand All @@ -92,7 +92,7 @@ def _handle_exception(self, exception):
return
segment = None
try:
if self.in_lambda:
if self.in_lambda_ctx:
segment = self._recorder.current_subsegment()
else:
segment = self._recorder.current_segment()
Expand All @@ -104,7 +104,7 @@ def _handle_exception(self, exception):
segment.put_http_meta(http.STATUS, 500)
stack = stacktrace.get_stacktrace(limit=self._recorder._max_trace_back)
segment.add_exception(exception, stack)
if self.in_lambda:
if self.in_lambda_ctx:
self._recorder.end_subsegment()
else:
self._recorder.end_segment()
Expand Down
9 changes: 8 additions & 1 deletion tests/ext/django/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from aws_xray_sdk.core import xray_recorder, lambda_launcher
from aws_xray_sdk.core.context import Context
from aws_xray_sdk.core.models import http, facade_segment
from aws_xray_sdk.core.models import http, facade_segment, segment
from tests.util import get_new_stubbed_recorder
import os

Expand Down Expand Up @@ -132,3 +132,10 @@ def test_lambda_serverless(self):
self.client.get(url)
segment = new_recorder.emitter.pop()
assert not segment

def test_lambda_default_ctx(self):
# Track to make sure that Django will default to generating segments if context is not the lambda context
url = reverse('200ok')
self.client.get(url)
cur_segment = xray_recorder.emitter.pop()
assert type(cur_segment) == segment.Segment
32 changes: 30 additions & 2 deletions tests/ext/flask/test_flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from aws_xray_sdk.ext.flask.middleware import XRayMiddleware
from aws_xray_sdk.core.context import Context
from aws_xray_sdk.core import lambda_launcher
from aws_xray_sdk.core.models import http, facade_segment
from aws_xray_sdk.core.models import http, facade_segment, segment
from tests.util import get_new_stubbed_recorder
import os

Expand Down Expand Up @@ -186,7 +186,7 @@ def trace_header():
return 'ok'

middleware = XRayMiddleware(new_app, new_recorder)
middleware.in_lambda = True
middleware.in_lambda_ctx = True

app_client = new_app.test_client()

Expand All @@ -197,3 +197,31 @@ def trace_header():

path2 = '/trace_header'
app_client.get(path2, headers={http.XRAY_HEADER: 'k1=v1'})


def test_lambda_default_ctx():
# Track to make sure that Flask will default to generating segments if context is not the lambda context
TRACE_ID = '1-5759e988-bd862e3fe1be46a994272793'
PARENT_ID = '53995c3f42cd8ad8'
HEADER_VAR = "Root=%s;Parent=%s;Sampled=1" % (TRACE_ID, PARENT_ID)

os.environ[lambda_launcher.LAMBDA_TRACE_HEADER_KEY] = HEADER_VAR

new_recorder = get_new_stubbed_recorder()
new_recorder.configure(service='test', sampling=False)
new_app = Flask(__name__)

@new_app.route('/segment')
def subsegment():
# Test in between request and make sure Lambda that uses default context generates a segment.
assert new_recorder.current_segment()
assert type(new_recorder.current_segment()) == segment.Segment
return 'ok'

XRayMiddleware(new_app, new_recorder)
app_client = new_app.test_client()

path = '/segment'
app_client.get(path)
segment = recorder.emitter.pop()
assert not segment # Segment should be none because it's created and ended by the middleware

0 comments on commit 512312d

Please sign in to comment.