diff --git a/CHANGELOG.asciidoc b/CHANGELOG.asciidoc index 2f66d885f..bf3f68751 100644 --- a/CHANGELOG.asciidoc +++ b/CHANGELOG.asciidoc @@ -37,6 +37,8 @@ endif::[] [float] ===== Features +* Add global access to Client singleton object at `elasticapm.get_client()` {pull}1043[#1043] + [float] ===== Bug fixes diff --git a/elasticapm/__init__.py b/elasticapm/__init__.py index 29e407934..42bbca423 100644 --- a/elasticapm/__init__.py +++ b/elasticapm/__init__.py @@ -29,7 +29,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE import sys -from elasticapm.base import Client +from elasticapm.base import Client, get_client # noqa: F401 from elasticapm.conf import setup_logging # noqa: F401 from elasticapm.instrumentation.control import instrument, uninstrument # noqa: F401 from elasticapm.traces import ( # noqa: F401 diff --git a/elasticapm/base.py b/elasticapm/base.py index 716de3c4e..cd3787a3a 100644 --- a/elasticapm/base.py +++ b/elasticapm/base.py @@ -54,6 +54,8 @@ __all__ = ("Client",) +CLIENT_SINGLETON = None + class Client(object): """ @@ -205,6 +207,9 @@ def __init__(self, config=None, **inline): if config.enabled: self.start_threads() + # Save this Client object as the global CLIENT_SINGLETON + set_client(self) + def start_threads(self): with self._thread_starter_lock: current_pid = os.getpid() @@ -298,6 +303,8 @@ def close(self): with self._thread_starter_lock: for _, manager in self._thread_managers.items(): manager.stop_thread() + global CLIENT_SINGLETON + CLIENT_SINGLETON = None def get_service_info(self): if self._service_info: @@ -611,3 +618,15 @@ class DummyClient(Client): def send(self, url, **kwargs): return None + + +def get_client(): + return CLIENT_SINGLETON + + +def set_client(client): + global CLIENT_SINGLETON + if CLIENT_SINGLETON: + logger = get_logger("elasticapm") + logger.debug("Client object is being set more than once") + CLIENT_SINGLETON = client diff --git a/elasticapm/contrib/aiohttp/__init__.py b/elasticapm/contrib/aiohttp/__init__.py index 75799caca..5588adfc4 100644 --- a/elasticapm/contrib/aiohttp/__init__.py +++ b/elasticapm/contrib/aiohttp/__init__.py @@ -33,8 +33,6 @@ import elasticapm from elasticapm import Client -CLIENT_KEY = "_elasticapm_client_instance" - class ElasticAPM: def __init__(self, app, client=None): @@ -43,7 +41,6 @@ def __init__(self, app, client=None): config.setdefault("framework_name", "aiohttp") config.setdefault("framework_version", aiohttp.__version__) client = Client(config=config) - app[CLIENT_KEY] = client self.app = app self.client = client self.install_tracing(app, client) diff --git a/elasticapm/contrib/aiohttp/middleware.py b/elasticapm/contrib/aiohttp/middleware.py index d83dbe9d6..20ea0309e 100644 --- a/elasticapm/contrib/aiohttp/middleware.py +++ b/elasticapm/contrib/aiohttp/middleware.py @@ -32,6 +32,7 @@ from aiohttp.web import HTTPException, Response, middleware import elasticapm +from elasticapm import get_client from elasticapm.conf import constants from elasticapm.contrib.aiohttp.utils import get_data_from_request, get_data_from_response from elasticapm.utils.disttracing import TraceParent @@ -44,13 +45,10 @@ def merge_duplicate_headers(cls, headers, key): def tracing_middleware(app): - from elasticapm.contrib.aiohttp import CLIENT_KEY # noqa - async def handle_request(request, handler): - elasticapm_client = app.get(CLIENT_KEY) + elasticapm_client = get_client() should_trace = elasticapm_client and not elasticapm_client.should_ignore_url(request.path) if should_trace: - request[CLIENT_KEY] = elasticapm_client trace_parent = AioHttpTraceParent.from_headers(request.headers) elasticapm_client.begin_transaction("request", trace_parent=trace_parent) resource = request.match_info.route.resource diff --git a/elasticapm/contrib/django/client.py b/elasticapm/contrib/django/client.py index fe84f8368..901aa67cb 100644 --- a/elasticapm/contrib/django/client.py +++ b/elasticapm/contrib/django/client.py @@ -37,6 +37,7 @@ from django.db import DatabaseError from django.http import HttpRequest +from elasticapm import get_client as _get_client from elasticapm.base import Client from elasticapm.conf import constants from elasticapm.contrib.django.utils import iterate_with_template_sources @@ -49,10 +50,9 @@ default_client_class = "elasticapm.contrib.django.DjangoClient" -_client = (None, None) -def get_client(client=None): +def get_client(): """ Get an ElasticAPM client. @@ -60,20 +60,16 @@ def get_client(client=None): :return: :rtype: elasticapm.base.Client """ - global _client - - tmp_client = client is not None - if not tmp_client: - config = getattr(django_settings, "ELASTIC_APM", {}) - client = config.get("CLIENT", default_client_class) - - if _client[0] != client: - client_class = import_string(client) - instance = client_class() - if not tmp_client: - _client = (client, instance) - return instance - return _client[1] + if _get_client(): + return _get_client() + + config = getattr(django_settings, "ELASTIC_APM", {}) + client = config.get("CLIENT", default_client_class) + client_class = import_string(client) + instance = client_class() + # `instance` will already be in elasticapm.base.CLIENT_SINGLETON due to the + # `__init__()` for Client + return instance class DjangoClient(Client): diff --git a/elasticapm/contrib/flask/__init__.py b/elasticapm/contrib/flask/__init__.py index e8fb5222f..361dca01c 100644 --- a/elasticapm/contrib/flask/__init__.py +++ b/elasticapm/contrib/flask/__init__.py @@ -38,6 +38,7 @@ import elasticapm import elasticapm.instrumentation.control +from elasticapm import get_client from elasticapm.base import Client from elasticapm.conf import constants, setup_logging from elasticapm.contrib.flask.utils import get_data_from_request, get_data_from_response @@ -50,17 +51,6 @@ logger = get_logger("elasticapm.errors.client") -def make_client(client_cls, app, **defaults): - config = app.config.get("ELASTIC_APM", {}) - - if "framework_name" not in defaults: - defaults["framework_name"] = "flask" - defaults["framework_version"] = getattr(flask, "__version__", "<0.7") - - client = client_cls(config, **defaults) - return client - - class ElasticAPM(object): """ Flask application for Elastic APM. @@ -97,10 +87,18 @@ class ElasticAPM(object): def __init__(self, app=None, client=None, client_cls=Client, logging=False, **defaults): self.app = app self.logging = logging - self.client_cls = client_cls - self.client = client + self.client = client or get_client() if app: + if not self.client: + config = app.config.get("ELASTIC_APM", {}) + + if "framework_name" not in defaults: + defaults["framework_name"] = "flask" + defaults["framework_version"] = getattr(flask, "__version__", "<0.7") + + self.client = client_cls(config, **defaults) + self.init_app(app, **defaults) def handle_exception(self, *args, **kwargs): @@ -125,8 +123,6 @@ def handle_exception(self, *args, **kwargs): def init_app(self, app, **defaults): self.app = app - if not self.client: - self.client = make_client(self.client_cls, app, **defaults) # 0 is a valid log level (NOTSET), so we need to check explicitly for it if self.logging or self.logging is logging.NOTSET: diff --git a/elasticapm/contrib/opentracing/tracer.py b/elasticapm/contrib/opentracing/tracer.py index cd65e66df..d51ee144e 100644 --- a/elasticapm/contrib/opentracing/tracer.py +++ b/elasticapm/contrib/opentracing/tracer.py @@ -36,17 +36,15 @@ from opentracing.tracer import Tracer as TracerBase import elasticapm -from elasticapm import instrument, traces +from elasticapm import get_client, instrument, traces from elasticapm.conf import constants from elasticapm.contrib.opentracing.span import OTSpan, OTSpanContext from elasticapm.utils import compat, disttracing class Tracer(TracerBase): - _elasticapm_client_class = elasticapm.Client - def __init__(self, client_instance=None, config=None, scope_manager=None): - self._agent = client_instance or self._elasticapm_client_class(config=config) + self._agent = client_instance or get_client() or elasticapm.Client(config=config) if scope_manager and not isinstance(scope_manager, ThreadLocalScopeManager): warnings.warn( "Currently, the Elastic APM opentracing bridge only supports the ThreadLocalScopeManager. " diff --git a/elasticapm/contrib/paste.py b/elasticapm/contrib/paste.py index f32564036..0c9a35231 100644 --- a/elasticapm/contrib/paste.py +++ b/elasticapm/contrib/paste.py @@ -29,10 +29,11 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from elasticapm import get_client from elasticapm.base import Client from elasticapm.middleware import ElasticAPM def filter_factory(app, global_conf, **kwargs): - client = Client(**kwargs) + client = get_client() or Client(**kwargs) return ElasticAPM(app, client) diff --git a/elasticapm/contrib/pylons/__init__.py b/elasticapm/contrib/pylons/__init__.py index 55a35916d..1cef9c67f 100644 --- a/elasticapm/contrib/pylons/__init__.py +++ b/elasticapm/contrib/pylons/__init__.py @@ -29,6 +29,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +from elasticapm import get_client from elasticapm.base import Client from elasticapm.middleware import ElasticAPM as Middleware from elasticapm.utils import compat @@ -44,5 +45,5 @@ def list_from_setting(config, setting): class ElasticAPM(Middleware): def __init__(self, app, config, client_cls=Client): client_config = {key[11:]: val for key, val in compat.iteritems(config) if key.startswith("elasticapm.")} - client = client_cls(**client_config) + client = get_client() or client_cls(**client_config) super(ElasticAPM, self).__init__(app, client) diff --git a/tests/contrib/django/django_tests.py b/tests/contrib/django/django_tests.py index d4ed6dcfd..a043b30b8 100644 --- a/tests/contrib/django/django_tests.py +++ b/tests/contrib/django/django_tests.py @@ -511,17 +511,6 @@ def test_response_error_id_middleware(django_elasticapm_client, client): assert event["id"] == headers["X-ElasticAPM-ErrorId"] -def test_get_client(django_elasticapm_client): - with mock.patch.dict("os.environ", {"ELASTIC_APM_METRICS_INTERVAL": "0ms"}): - client2 = get_client("elasticapm.base.Client") - try: - assert get_client() is get_client() - assert client2.__class__ == Client - finally: - get_client().close() - client2.close() - - @pytest.mark.parametrize("django_elasticapm_client", [{"capture_body": "errors"}], indirect=True) def test_raw_post_data_partial_read(django_elasticapm_client): v = compat.b('{"foo": "bar"}')