diff --git a/logging/google/cloud/logging/handlers/_helpers.py b/logging/google/cloud/logging/handlers/_helpers.py index e5bdabe2e388..8eb2434a27a0 100644 --- a/logging/google/cloud/logging/handlers/_helpers.py +++ b/logging/google/cloud/logging/handlers/_helpers.py @@ -19,10 +19,11 @@ try: import flask -except ImportError: +except ImportError: # pragma: NO COVER flask = None -from google.cloud.logging.handlers.middleware.request import _get_django_request +from google.cloud.logging.handlers.middleware.request import ( + _get_django_request) _FLASK_TRACE_HEADER = 'X_CLOUD_TRACE_CONTEXT' _DJANGO_TRACE_HEADER = 'HTTP_X_CLOUD_TRACE_CONTEXT' @@ -79,13 +80,10 @@ def get_trace_id_from_django(): if request is None: return None - try: - header = request.META.get(_DJANGO_TRACE_HEADER) - except KeyError: - return None - + header = request.META.get(_DJANGO_TRACE_HEADER) if header is None: return None + trace_id = header.split('/')[0] return trace_id diff --git a/logging/google/cloud/logging/handlers/app_engine.py b/logging/google/cloud/logging/handlers/app_engine.py index 4561b83555c4..509bf8002fb1 100644 --- a/logging/google/cloud/logging/handlers/app_engine.py +++ b/logging/google/cloud/logging/handlers/app_engine.py @@ -73,7 +73,10 @@ def get_gae_resource(self): return gae_resource def get_gae_labels(self): - """Return the labels for GAE app which includes trace_id. + """Return the labels for GAE app. + + If the trace ID can be detected, it will be included as a label. + Currently, no other labels are included. :rtype: dict :returns: Labels for GAE app. diff --git a/logging/google/cloud/logging/handlers/middleware/request.py b/logging/google/cloud/logging/handlers/middleware/request.py index 1acf7ffd9961..4c0b22a8e96b 100644 --- a/logging/google/cloud/logging/handlers/middleware/request.py +++ b/logging/google/cloud/logging/handlers/middleware/request.py @@ -12,6 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Django middleware helper to capture a request. + +The request is stored on a thread-local so that it can be +inspected by other helpers. +""" + import threading diff --git a/logging/tests/unit/handlers/middleware/test_request.py b/logging/tests/unit/handlers/middleware/test_request.py index ab8c18d1aa0e..983d67129647 100644 --- a/logging/tests/unit/handlers/middleware/test_request.py +++ b/logging/tests/unit/handlers/middleware/test_request.py @@ -14,16 +14,10 @@ import unittest +import mock -class TestRequestMiddleware(unittest.TestCase): - def _get_target_class(self): - from google.cloud.logging.handlers.middleware.request import RequestMiddleware - - return RequestMiddleware - - def _make_one(self, *args, **kw): - return self._get_target_class()(*args, **kw) +class DjangoBase(unittest.TestCase): @classmethod def setUpClass(cls): @@ -40,11 +34,53 @@ def tearDownClass(cls): teardown_test_environment() - def test_get_django_request(self): + +class TestRequestMiddleware(DjangoBase): + + def _get_target_class(self): + from google.cloud.logging.handlers.middleware import request + + return request.RequestMiddleware + + def _make_one(self, *args, **kw): + return self._get_target_class()(*args, **kw) + + def test_process_request(self): from django.test import RequestFactory - from google.cloud.logging.handlers.middleware.request import _get_django_request + from google.cloud.logging.handlers.middleware import request middleware = self._make_one() - request = RequestFactory().get('/') - middleware.process_request(request) - self.assertEqual(_get_django_request(), request) + mock_request = RequestFactory().get('/') + middleware.process_request(mock_request) + + django_request = request._get_django_request() + self.assertEqual(django_request, mock_request) + + +class Test__get_django_request(DjangoBase): + + @staticmethod + def _call_fut(): + from google.cloud.logging.handlers.middleware import request + + return request._get_django_request() + + @staticmethod + def _make_patch(new_locals): + return mock.patch( + 'google.cloud.logging.handlers.middleware.request._thread_locals', + new=new_locals) + + def test_with_request(self): + thread_locals = mock.Mock(spec=['request']) + with self._make_patch(thread_locals): + django_request = self._call_fut() + + self.assertIs(django_request, thread_locals.request) + + def test_without_request(self): + thread_locals = mock.Mock(spec=[]) + with self._make_patch(thread_locals): + django_request = self._call_fut() + + self.assertIsNone(django_request) diff --git a/logging/tests/unit/handlers/test__helpers.py b/logging/tests/unit/handlers/test__helpers.py index c50ce0044e85..a69e765afa98 100644 --- a/logging/tests/unit/handlers/test__helpers.py +++ b/logging/tests/unit/handlers/test__helpers.py @@ -14,9 +14,17 @@ import unittest +import mock + class Test_get_trace_id_from_flask(unittest.TestCase): + @staticmethod + def _call_fut(): + from google.cloud.logging.handlers import _helpers + + return _helpers.get_trace_id_from_flask() + @staticmethod def create_app(): import flask @@ -30,24 +38,19 @@ def index(): return app def setUp(self): - self.app = Test_get_trace_id_from_flask.create_app() + self.app = self.create_app() - def test_trace_id_no_context_header(self): + def test_no_context_header(self): from google.cloud.logging.handlers import _helpers with self.app.test_request_context( path='/', headers={}): - trace_id = _helpers.get_trace_id_from_flask() - trace_id_returned = _helpers.get_trace_id() + trace_id = self._call_fut() self.assertIsNone(trace_id) - self.assertIsNone(trace_id_returned) - - def test_trace_id_valid_context_header(self): - from google.cloud.logging.handlers._helpers import get_trace_id_from_flask - from google.cloud.logging.handlers._helpers import get_trace_id + def test_valid_context_header(self): flask_trace_header = 'X_CLOUD_TRACE_CONTEXT' expected_trace_id = 'testtraceidflask' flask_trace_id = expected_trace_id + '/testspanid' @@ -57,15 +60,19 @@ def test_trace_id_valid_context_header(self): headers={flask_trace_header: flask_trace_id}) with context: - trace_id = get_trace_id_from_flask() - trace_id_returned = get_trace_id() + trace_id = self._call_fut() self.assertEqual(trace_id, expected_trace_id) - self.assertEqual(trace_id, trace_id_returned) class Test_get_trace_id_from_django(unittest.TestCase): + @staticmethod + def _call_fut(): + from google.cloud.logging.handlers import _helpers + + return _helpers.get_trace_id_from_django() + def setUp(self): from django.conf import settings from django.test.utils import setup_test_environment @@ -76,47 +83,91 @@ def setUp(self): def tearDown(self): from django.test.utils import teardown_test_environment + from google.cloud.logging.handlers.middleware import request teardown_test_environment() + request._thread_locals.__dict__.clear() - def test_trace_id_no_context_header(self): + def test_no_context_header(self): from django.test import RequestFactory - from google.cloud.logging.handlers import _helpers from google.cloud.logging.handlers.middleware import request django_request = RequestFactory().get('/') middleware = request.RequestMiddleware() middleware.process_request(django_request) - trace_id = _helpers.get_trace_id_from_django() - trace_id_returned = _helpers.get_trace_id() - + trace_id = self._call_fut() self.assertIsNone(trace_id) - self.assertIsNone(trace_id_returned) - request._thread_locals.__dict__.clear() - - def test_trace_id_valid_context_header(self): + def test_valid_context_header(self): from django.test import RequestFactory - from google.cloud.logging.handlers._helpers import get_trace_id_from_django - from google.cloud.logging.handlers._helpers import get_trace_id - from google.cloud.logging.handlers.middleware.request import RequestMiddleware - from google.cloud.logging.handlers.middleware.request import _thread_locals + from google.cloud.logging.handlers.middleware import request django_trace_header = 'HTTP_X_CLOUD_TRACE_CONTEXT' expected_trace_id = 'testtraceiddjango' django_trace_id = expected_trace_id + '/testspanid' - request = RequestFactory().get( + django_request = RequestFactory().get( '/', **{django_trace_header: django_trace_id}) - middleware = RequestMiddleware() - middleware.process_request(request) - trace_id = get_trace_id_from_django() - trace_id_returned = get_trace_id() + middleware = request.RequestMiddleware() + middleware.process_request(django_request) + trace_id = self._call_fut() self.assertEqual(trace_id, expected_trace_id) - self.assertEqual(trace_id, trace_id_returned) - _thread_locals.__dict__.clear() + +class Test_get_trace_id(unittest.TestCase): + + @staticmethod + def _call_fut(): + from google.cloud.logging.handlers import _helpers + + return _helpers.get_trace_id() + + def _helper(self, django_return, flask_return): + django_patch = mock.patch( + 'google.cloud.logging.handlers._helpers.get_trace_id_from_django', + return_value=django_return) + flask_patch = mock.patch( + 'google.cloud.logging.handlers._helpers.get_trace_id_from_flask', + return_value=flask_return) + + with django_patch as django_mock: + with flask_patch as flask_mock: + trace_id = self._call_fut() + + return django_mock, flask_mock, trace_id + + def test_from_django(self): + django_mock, flask_mock, trace_id = self._helper( + 'test-django-trace-id', None) + self.assertEqual(trace_id, django_mock.return_value) + + django_mock.assert_called_once_with() + flask_mock.assert_not_called() + + def test_from_flask(self): + django_mock, flask_mock, trace_id = self._helper( + None, 'test-flask-trace-id') + self.assertEqual(trace_id, flask_mock.return_value) + + django_mock.assert_called_once_with() + flask_mock.assert_called_once_with() + + def test_from_django_and_flask(self): + django_mock, flask_mock, trace_id = self._helper( + 'test-django-trace-id', 'test-flask-trace-id') + # Django wins. + self.assertEqual(trace_id, django_mock.return_value) + + django_mock.assert_called_once_with() + flask_mock.assert_not_called() + + def test_missing(self): + django_mock, flask_mock, trace_id = self._helper(None, None) + self.assertIsNone(trace_id) + + django_mock.assert_called_once_with() + flask_mock.assert_called_once_with() diff --git a/logging/tests/unit/handlers/test_app_engine.py b/logging/tests/unit/handlers/test_app_engine.py index ccdaf2992335..6438c4abb8a0 100644 --- a/logging/tests/unit/handlers/test_app_engine.py +++ b/logging/tests/unit/handlers/test_app_engine.py @@ -15,8 +15,10 @@ import logging import unittest +import mock -class TestAppEngineHandlerHandler(unittest.TestCase): + +class TestAppEngineHandler(unittest.TestCase): PROJECT = 'PROJECT' def _get_target_class(self): @@ -28,7 +30,6 @@ def _make_one(self, *args, **kw): return self._get_target_class()(*args, **kw) def test_constructor(self): - import mock from google.cloud.logging.handlers.app_engine import _GAE_PROJECT_ENV from google.cloud.logging.handlers.app_engine import _GAE_SERVICE_ENV from google.cloud.logging.handlers.app_engine import _GAE_VERSION_ENV @@ -48,8 +49,6 @@ def test_constructor(self): self.assertEqual(handler.labels, {}) def test_emit(self): - import mock - client = mock.Mock(project=self.PROJECT, spec=['project']) handler = self._make_one(client, transport=_Transport) gae_resource = handler.get_gae_resource() @@ -66,6 +65,35 @@ def test_emit(self): handler.transport.send_called_with, (record, message, gae_resource, gae_labels)) + def _get_gae_labels_helper(self, trace_id): + get_trace_patch = mock.patch( + 'google.cloud.logging.handlers.app_engine.get_trace_id', + return_value=trace_id) + + client = mock.Mock(project=self.PROJECT, spec=['project']) + # The handler actually calls ``get_gae_labels()``. + with get_trace_patch as mock_get_trace: + handler = self._make_one(client, transport=_Transport) + mock_get_trace.assert_called_once_with() + + gae_labels = handler.get_gae_labels() + self.assertEqual(mock_get_trace.mock_calls, + [mock.call(), mock.call()]) + + return gae_labels + + def test_get_gae_labels_with_label(self): + from google.cloud.logging.handlers import app_engine + + trace_id = 'test-gae-trace-id' + gae_labels = self._get_gae_labels_helper(trace_id) + expected_labels = {app_engine._TRACE_ID_LABEL: trace_id} + self.assertEqual(gae_labels, expected_labels) + + def test_get_gae_labels_without_label(self): + gae_labels = self._get_gae_labels_helper(None) + self.assertEqual(gae_labels, {}) + class _Transport(object):