Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ref(client): Improve get_integration typing #3550

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions sentry_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import Mapping
from datetime import datetime, timezone
from importlib import import_module
from typing import cast
from typing import cast, overload

from sentry_sdk._compat import PY37, check_uwsgi_thread_support
from sentry_sdk.utils import (
Expand Down Expand Up @@ -54,6 +54,7 @@
from typing import Sequence
from typing import Type
from typing import Union
from typing import TypeVar

from sentry_sdk._types import Event, Hint, SDKInfo
from sentry_sdk.integrations import Integration
Expand All @@ -62,6 +63,7 @@
from sentry_sdk.session import Session
from sentry_sdk.transport import Transport

I = TypeVar("I", bound=Integration) # noqa: E741

_client_init_debug = ContextVar("client_init_debug")

Expand Down Expand Up @@ -195,8 +197,20 @@ def capture_session(self, *args, **kwargs):
# type: (*Any, **Any) -> None
return None

def get_integration(self, *args, **kwargs):
# type: (*Any, **Any) -> Any
if TYPE_CHECKING:

@overload
def get_integration(self, name_or_class):
# type: (str) -> Optional[Integration]
...

@overload
def get_integration(self, name_or_class):
# type: (type[I]) -> Optional[I]
...

def get_integration(self, name_or_class):
# type: (Union[str, type[Integration]]) -> Optional[Integration]
return None

def close(self, *args, **kwargs):
Expand Down Expand Up @@ -815,10 +829,22 @@ def capture_session(
else:
self.session_flusher.add_session(session)

if TYPE_CHECKING:

@overload
def get_integration(self, name_or_class):
# type: (str) -> Optional[Integration]
...

@overload
def get_integration(self, name_or_class):
# type: (type[I]) -> Optional[I]
...

def get_integration(
self, name_or_class # type: Union[str, Type[Integration]]
):
# type: (...) -> Any
# type: (...) -> Optional[Integration]
"""Returns the integration for this client by name or class.
If the client does not have that integration then `None` is returned.
"""
Expand Down
4 changes: 4 additions & 0 deletions sentry_sdk/integrations/aiohttp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
import weakref
from functools import wraps

import sentry_sdk
from sentry_sdk.api import continue_trace
Expand Down Expand Up @@ -146,11 +147,14 @@ async def sentry_app_handle(self, request, *args, **kwargs):

old_urldispatcher_resolve = UrlDispatcher.resolve

@wraps(old_urldispatcher_resolve)
async def sentry_urldispatcher_resolve(self, request):
# type: (UrlDispatcher, Request) -> UrlMappingMatchInfo
rv = await old_urldispatcher_resolve(self, request)

integration = sentry_sdk.get_client().get_integration(AioHttpIntegration)
if integration is None:
return rv

name = None

Expand Down
8 changes: 3 additions & 5 deletions sentry_sdk/integrations/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.utils import (
capture_internal_exceptions,
ensure_integration_enabled,
event_from_exception,
package_version,
)
Expand Down Expand Up @@ -78,10 +77,11 @@ def _calculate_token_usage(result, span):
def _wrap_message_create(f):
# type: (Any) -> Any
@wraps(f)
@ensure_integration_enabled(AnthropicIntegration, f)
def _sentry_patched_create(*args, **kwargs):
# type: (*Any, **Any) -> Any
if "messages" not in kwargs:
integration = sentry_sdk.get_client().get_integration(AnthropicIntegration)

if integration is None or "messages" not in kwargs:
return f(*args, **kwargs)

try:
Expand All @@ -106,8 +106,6 @@ def _sentry_patched_create(*args, **kwargs):
span.__exit__(None, None, None)
raise exc from None

integration = sentry_sdk.get_client().get_integration(AnthropicIntegration)

with capture_internal_exceptions():
span.set_data(SPANDATA.AI_MODEL_ID, model)
span.set_data(SPANDATA.AI_STREAMING, False)
Expand Down
5 changes: 4 additions & 1 deletion sentry_sdk/integrations/atexit.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,7 @@ def _shutdown():

logger.debug("atexit: shutting down client")
sentry_sdk.get_isolation_scope().end_session()
client.close(callback=integration.callback)

if integration is not None:
# Should not be None, but mypy doesn't know that
client.close(callback=integration.callback)
12 changes: 8 additions & 4 deletions sentry_sdk/integrations/aws_lambda.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import json
import re
import sys
Expand Down Expand Up @@ -70,7 +71,7 @@ def sentry_init_error(*args, **kwargs):

def _wrap_handler(handler):
# type: (F) -> F
@ensure_integration_enabled(AwsLambdaIntegration, handler)
@functools.wraps(handler)
def sentry_handler(aws_event, aws_context, *args, **kwargs):
# type: (Any, Any, *Any, **Any) -> Any

Expand All @@ -84,6 +85,12 @@ def sentry_handler(aws_event, aws_context, *args, **kwargs):
# will be the same for all events in the list, since they're all hitting
# the lambda in the same request.)

client = sentry_sdk.get_client()
integration = client.get_integration(AwsLambdaIntegration)

if integration is None:
return handler(aws_event, aws_context, *args, **kwargs)

if isinstance(aws_event, list) and len(aws_event) >= 1:
request_data = aws_event[0]
batch_size = len(aws_event)
Expand All @@ -97,9 +104,6 @@ def sentry_handler(aws_event, aws_context, *args, **kwargs):
# this is empty
request_data = {}

client = sentry_sdk.get_client()
integration = client.get_integration(AwsLambdaIntegration)

configured_time = aws_context.get_remaining_time_in_millis()

with sentry_sdk.isolation_scope() as scope:
Expand Down
6 changes: 5 additions & 1 deletion sentry_sdk/integrations/bottle.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import functools

import sentry_sdk
from sentry_sdk.tracing import SOURCE_FOR_STYLE
from sentry_sdk.utils import (
Expand Down Expand Up @@ -81,10 +83,12 @@ def sentry_patched_wsgi_app(self, environ, start_response):

old_handle = Bottle._handle

@ensure_integration_enabled(BottleIntegration, old_handle)
@functools.wraps(old_handle)
def _patched_handle(self, environ):
# type: (Bottle, Dict[str, Any]) -> Any
integration = sentry_sdk.get_client().get_integration(BottleIntegration)
if integration is None:
return old_handle(self, environ)

scope = sentry_sdk.get_isolation_scope()
scope._name = "bottle"
Expand Down
6 changes: 4 additions & 2 deletions sentry_sdk/integrations/celery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,13 +248,15 @@ def __exit__(self, exc_type, exc_value, traceback):
def _wrap_task_run(f):
# type: (F) -> F
@wraps(f)
@ensure_integration_enabled(CeleryIntegration, f)
def apply_async(*args, **kwargs):
# type: (*Any, **Any) -> Any
# Note: kwargs can contain headers=None, so no setdefault!
# Unsure which backend though.
kwarg_headers = kwargs.get("headers") or {}
integration = sentry_sdk.get_client().get_integration(CeleryIntegration)
if integration is None:
return f(*args, **kwargs)

kwarg_headers = kwargs.get("headers") or {}
propagate_traces = kwarg_headers.pop(
"sentry-propagate-traces", integration.propagate_traces
)
Expand Down
24 changes: 11 additions & 13 deletions sentry_sdk/integrations/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,7 @@
import sentry_sdk
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.integrations import DidNotEnable, Integration
from sentry_sdk.utils import (
capture_internal_exceptions,
event_from_exception,
ensure_integration_enabled,
)
from sentry_sdk.utils import capture_internal_exceptions, event_from_exception

try:
from cohere.client import Client
Expand Down Expand Up @@ -134,13 +130,15 @@ def collect_chat_response_fields(span, res, include_pii):
set_data_normalized(span, "ai.warnings", res.meta.warnings)

@wraps(f)
@ensure_integration_enabled(CohereIntegration, f)
def new_chat(*args, **kwargs):
# type: (*Any, **Any) -> Any
if "message" not in kwargs:
return f(*args, **kwargs)
integration = sentry_sdk.get_client().get_integration(CohereIntegration)

if not isinstance(kwargs.get("message"), str):
if (
integration is None
or "message" not in kwargs
or not isinstance(kwargs.get("message"), str)
):
return f(*args, **kwargs)

message = kwargs.get("message")
Expand All @@ -158,8 +156,6 @@ def new_chat(*args, **kwargs):
span.__exit__(None, None, None)
raise e from None

integration = sentry_sdk.get_client().get_integration(CohereIntegration)

with capture_internal_exceptions():
if should_send_default_pii() and integration.include_prompts:
set_data_normalized(
Expand Down Expand Up @@ -227,15 +223,17 @@ def _wrap_embed(f):
# type: (Callable[..., Any]) -> Callable[..., Any]

@wraps(f)
@ensure_integration_enabled(CohereIntegration, f)
def new_embed(*args, **kwargs):
# type: (*Any, **Any) -> Any
integration = sentry_sdk.get_client().get_integration(CohereIntegration)
if integration is None:
return f(*args, **kwargs)

with sentry_sdk.start_span(
op=consts.OP.COHERE_EMBEDDINGS_CREATE,
name="Cohere Embedding Creation",
origin=CohereIntegration.origin,
) as span:
integration = sentry_sdk.get_client().get_integration(CohereIntegration)
if "texts" in kwargs and (
should_send_default_pii() and integration.include_prompts
):
Expand Down
9 changes: 5 additions & 4 deletions sentry_sdk/integrations/django/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,10 +411,11 @@ def _set_transaction_name_and_source(scope, transaction_style, request):
pass


@ensure_integration_enabled(DjangoIntegration)
def _before_get_response(request):
# type: (WSGIRequest) -> None
integration = sentry_sdk.get_client().get_integration(DjangoIntegration)
if integration is None:
return

_patch_drf()

Expand All @@ -440,11 +441,10 @@ def _attempt_resolve_again(request, scope, transaction_style):
_set_transaction_name_and_source(scope, transaction_style, request)


@ensure_integration_enabled(DjangoIntegration)
def _after_get_response(request):
# type: (WSGIRequest) -> None
integration = sentry_sdk.get_client().get_integration(DjangoIntegration)
if integration.transaction_style != "url":
if integration is None or integration.transaction_style != "url":
return

scope = sentry_sdk.get_current_scope()
Expand Down Expand Up @@ -510,11 +510,12 @@ def wsgi_request_event_processor(event, hint):
return wsgi_request_event_processor


@ensure_integration_enabled(DjangoIntegration)
def _got_request_exception(request=None, **kwargs):
# type: (WSGIRequest, **Any) -> None
client = sentry_sdk.get_client()
integration = client.get_integration(DjangoIntegration)
if integration is None:
return

if request is not None and integration.transaction_style == "url":
scope = sentry_sdk.get_current_scope()
Expand Down
4 changes: 2 additions & 2 deletions sentry_sdk/integrations/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,10 @@ def _sentry_call(*args, **kwargs):

async def _sentry_app(*args, **kwargs):
# type: (*Any, **Any) -> Any
if sentry_sdk.get_client().get_integration(FastApiIntegration) is None:
integration = sentry_sdk.get_client().get_integration(FastApiIntegration)
if integration is None:
return await old_app(*args, **kwargs)

integration = sentry_sdk.get_client().get_integration(FastApiIntegration)
request = args[0]

_set_transaction_name_and_source(
Expand Down
4 changes: 3 additions & 1 deletion sentry_sdk/integrations/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,12 @@ def _set_transaction_name_and_source(scope, transaction_style, request):
pass


@ensure_integration_enabled(FlaskIntegration)
def _request_started(app, **kwargs):
# type: (Flask, **Any) -> None
integration = sentry_sdk.get_client().get_integration(FlaskIntegration)
if integration is None:
return

request = flask_request._get_current_object()

# Set the transaction name and source here,
Expand Down
6 changes: 4 additions & 2 deletions sentry_sdk/integrations/gcp.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import sys
from copy import deepcopy
from datetime import datetime, timedelta, timezone
Expand All @@ -13,7 +14,6 @@
from sentry_sdk.utils import (
AnnotatedValue,
capture_internal_exceptions,
ensure_integration_enabled,
event_from_exception,
logger,
TimeoutThread,
Expand All @@ -39,12 +39,14 @@

def _wrap_func(func):
# type: (F) -> F
@ensure_integration_enabled(GcpIntegration, func)
@functools.wraps(func)
def sentry_func(functionhandler, gcp_event, *args, **kwargs):
# type: (Any, Any, *Any, **Any) -> Any
client = sentry_sdk.get_client()

integration = client.get_integration(GcpIntegration)
if integration is None:
return func(functionhandler, gcp_event, *args, **kwargs)

configured_time = environ.get("FUNCTION_TIMEOUT_SEC")
if not configured_time:
Expand Down
8 changes: 4 additions & 4 deletions sentry_sdk/integrations/huggingface_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from sentry_sdk.utils import (
capture_internal_exceptions,
event_from_exception,
ensure_integration_enabled,
)

try:
Expand Down Expand Up @@ -55,9 +54,12 @@ def _capture_exception(exc):
def _wrap_text_generation(f):
# type: (Callable[..., Any]) -> Callable[..., Any]
@wraps(f)
@ensure_integration_enabled(HuggingfaceHubIntegration, f)
def new_text_generation(*args, **kwargs):
# type: (*Any, **Any) -> Any
integration = sentry_sdk.get_client().get_integration(HuggingfaceHubIntegration)
if integration is None:
return f(*args, **kwargs)

if "prompt" in kwargs:
prompt = kwargs["prompt"]
elif len(args) >= 2:
Expand All @@ -84,8 +86,6 @@ def new_text_generation(*args, **kwargs):
span.__exit__(None, None, None)
raise e from None

integration = sentry_sdk.get_client().get_integration(HuggingfaceHubIntegration)

with capture_internal_exceptions():
if should_send_default_pii() and integration.include_prompts:
set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, prompt)
Expand Down
Loading
Loading