Skip to content

Commit

Permalink
ref(client): Improve get_integration typing
Browse files Browse the repository at this point in the history
Improve `get_integration` typing to make it clear that we return an `Optional[Integration]`. Further, add overloads to specify that when called with some integration type `I` (i.e. `I` is a subclass of `Integration`), then `get_integration` guarantees a return value of `Optional[I]`.

These changes should enhance type safety by explicitly guaranteeing the existing behavior of `get_integration`.
  • Loading branch information
szokeasaurusrex committed Sep 20, 2024
1 parent 64e2977 commit b3678cb
Show file tree
Hide file tree
Showing 21 changed files with 130 additions and 77 deletions.
34 changes: 30 additions & 4 deletions sentry_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import Mapping
from datetime import datetime, timezone
from importlib import import_module
from typing import cast
from typing import cast, overload

from sentry_sdk._compat import PY37, check_uwsgi_thread_support
from sentry_sdk.utils import (
Expand Down Expand Up @@ -54,6 +54,7 @@
from typing import Sequence
from typing import Type
from typing import Union
from typing import TypeVar

from sentry_sdk._types import Event, Hint, SDKInfo
from sentry_sdk.integrations import Integration
Expand All @@ -62,6 +63,7 @@
from sentry_sdk.session import Session
from sentry_sdk.transport import Transport

I = TypeVar("I", bound=Integration) # noqa: E741

_client_init_debug = ContextVar("client_init_debug")

Expand Down Expand Up @@ -195,8 +197,20 @@ def capture_session(self, *args, **kwargs):
# type: (*Any, **Any) -> None
return None

def get_integration(self, *args, **kwargs):
# type: (*Any, **Any) -> Any
if TYPE_CHECKING:

@overload
def get_integration(self, name_or_class):
# type: (str) -> Optional[Integration]
...

@overload
def get_integration(self, name_or_class):
# type: (type[I]) -> Optional[I]
...

def get_integration(self, name_or_class):
# type: (Union[str, type[Integration]]) -> Optional[Integration]
return None

def close(self, *args, **kwargs):
Expand Down Expand Up @@ -815,10 +829,22 @@ def capture_session(
else:
self.session_flusher.add_session(session)

if TYPE_CHECKING:

@overload
def get_integration(self, name_or_class):
# type: (str) -> Optional[Integration]
...

@overload
def get_integration(self, name_or_class):
# type: (type[I]) -> Optional[I]
...

def get_integration(
self, name_or_class # type: Union[str, Type[Integration]]
):
# type: (...) -> Any
# type: (...) -> Optional[Integration]
"""Returns the integration for this client by name or class.
If the client does not have that integration then `None` is returned.
"""
Expand Down
4 changes: 4 additions & 0 deletions sentry_sdk/integrations/aiohttp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
import weakref
from functools import wraps

import sentry_sdk
from sentry_sdk.api import continue_trace
Expand Down Expand Up @@ -146,11 +147,14 @@ async def sentry_app_handle(self, request, *args, **kwargs):

old_urldispatcher_resolve = UrlDispatcher.resolve

@wraps(old_urldispatcher_resolve)
async def sentry_urldispatcher_resolve(self, request):
# type: (UrlDispatcher, Request) -> UrlMappingMatchInfo
rv = await old_urldispatcher_resolve(self, request)

integration = sentry_sdk.get_client().get_integration(AioHttpIntegration)
if integration is None:
return rv

name = None

Expand Down
8 changes: 3 additions & 5 deletions sentry_sdk/integrations/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.utils import (
capture_internal_exceptions,
ensure_integration_enabled,
event_from_exception,
package_version,
)
Expand Down Expand Up @@ -78,10 +77,11 @@ def _calculate_token_usage(result, span):
def _wrap_message_create(f):
# type: (Any) -> Any
@wraps(f)
@ensure_integration_enabled(AnthropicIntegration, f)
def _sentry_patched_create(*args, **kwargs):
# type: (*Any, **Any) -> Any
if "messages" not in kwargs:
integration = sentry_sdk.get_client().get_integration(AnthropicIntegration)

if integration is None or "messages" not in kwargs:
return f(*args, **kwargs)

try:
Expand All @@ -106,8 +106,6 @@ def _sentry_patched_create(*args, **kwargs):
span.__exit__(None, None, None)
raise exc from None

integration = sentry_sdk.get_client().get_integration(AnthropicIntegration)

with capture_internal_exceptions():
span.set_data(SPANDATA.AI_MODEL_ID, model)
span.set_data(SPANDATA.AI_STREAMING, False)
Expand Down
5 changes: 4 additions & 1 deletion sentry_sdk/integrations/atexit.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,7 @@ def _shutdown():

logger.debug("atexit: shutting down client")
sentry_sdk.get_isolation_scope().end_session()
client.close(callback=integration.callback)

if integration is not None:
# Should not be None, but mypy doesn't know that
client.close(callback=integration.callback)
12 changes: 8 additions & 4 deletions sentry_sdk/integrations/aws_lambda.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import json
import re
import sys
Expand Down Expand Up @@ -70,7 +71,7 @@ def sentry_init_error(*args, **kwargs):

def _wrap_handler(handler):
# type: (F) -> F
@ensure_integration_enabled(AwsLambdaIntegration, handler)
@functools.wraps(handler)
def sentry_handler(aws_event, aws_context, *args, **kwargs):
# type: (Any, Any, *Any, **Any) -> Any

Expand All @@ -84,6 +85,12 @@ def sentry_handler(aws_event, aws_context, *args, **kwargs):
# will be the same for all events in the list, since they're all hitting
# the lambda in the same request.)

client = sentry_sdk.get_client()
integration = client.get_integration(AwsLambdaIntegration)

if integration is None:
return handler(aws_event, aws_context, *args, **kwargs)

if isinstance(aws_event, list) and len(aws_event) >= 1:
request_data = aws_event[0]
batch_size = len(aws_event)
Expand All @@ -97,9 +104,6 @@ def sentry_handler(aws_event, aws_context, *args, **kwargs):
# this is empty
request_data = {}

client = sentry_sdk.get_client()
integration = client.get_integration(AwsLambdaIntegration)

configured_time = aws_context.get_remaining_time_in_millis()

with sentry_sdk.isolation_scope() as scope:
Expand Down
6 changes: 5 additions & 1 deletion sentry_sdk/integrations/bottle.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import functools

import sentry_sdk
from sentry_sdk.tracing import SOURCE_FOR_STYLE
from sentry_sdk.utils import (
Expand Down Expand Up @@ -81,10 +83,12 @@ def sentry_patched_wsgi_app(self, environ, start_response):

old_handle = Bottle._handle

@ensure_integration_enabled(BottleIntegration, old_handle)
@functools.wraps(old_handle)
def _patched_handle(self, environ):
# type: (Bottle, Dict[str, Any]) -> Any
integration = sentry_sdk.get_client().get_integration(BottleIntegration)
if integration is None:
return old_handle(self, environ)

scope = sentry_sdk.get_isolation_scope()
scope._name = "bottle"
Expand Down
6 changes: 4 additions & 2 deletions sentry_sdk/integrations/celery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,13 +248,15 @@ def __exit__(self, exc_type, exc_value, traceback):
def _wrap_task_run(f):
# type: (F) -> F
@wraps(f)
@ensure_integration_enabled(CeleryIntegration, f)
def apply_async(*args, **kwargs):
# type: (*Any, **Any) -> Any
# Note: kwargs can contain headers=None, so no setdefault!
# Unsure which backend though.
kwarg_headers = kwargs.get("headers") or {}
integration = sentry_sdk.get_client().get_integration(CeleryIntegration)
if integration is None:
return f(*args, **kwargs)

kwarg_headers = kwargs.get("headers") or {}
propagate_traces = kwarg_headers.pop(
"sentry-propagate-traces", integration.propagate_traces
)
Expand Down
24 changes: 11 additions & 13 deletions sentry_sdk/integrations/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,7 @@
import sentry_sdk
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.integrations import DidNotEnable, Integration
from sentry_sdk.utils import (
capture_internal_exceptions,
event_from_exception,
ensure_integration_enabled,
)
from sentry_sdk.utils import capture_internal_exceptions, event_from_exception

try:
from cohere.client import Client
Expand Down Expand Up @@ -134,13 +130,15 @@ def collect_chat_response_fields(span, res, include_pii):
set_data_normalized(span, "ai.warnings", res.meta.warnings)

@wraps(f)
@ensure_integration_enabled(CohereIntegration, f)
def new_chat(*args, **kwargs):
# type: (*Any, **Any) -> Any
if "message" not in kwargs:
return f(*args, **kwargs)
integration = sentry_sdk.get_client().get_integration(CohereIntegration)

if not isinstance(kwargs.get("message"), str):
if (
integration is None
or "message" not in kwargs
or not isinstance(kwargs.get("message"), str)
):
return f(*args, **kwargs)

message = kwargs.get("message")
Expand All @@ -158,8 +156,6 @@ def new_chat(*args, **kwargs):
span.__exit__(None, None, None)
raise e from None

integration = sentry_sdk.get_client().get_integration(CohereIntegration)

with capture_internal_exceptions():
if should_send_default_pii() and integration.include_prompts:
set_data_normalized(
Expand Down Expand Up @@ -227,15 +223,17 @@ def _wrap_embed(f):
# type: (Callable[..., Any]) -> Callable[..., Any]

@wraps(f)
@ensure_integration_enabled(CohereIntegration, f)
def new_embed(*args, **kwargs):
# type: (*Any, **Any) -> Any
integration = sentry_sdk.get_client().get_integration(CohereIntegration)
if integration is None:
return f(*args, **kwargs)

with sentry_sdk.start_span(
op=consts.OP.COHERE_EMBEDDINGS_CREATE,
name="Cohere Embedding Creation",
origin=CohereIntegration.origin,
) as span:
integration = sentry_sdk.get_client().get_integration(CohereIntegration)
if "texts" in kwargs and (
should_send_default_pii() and integration.include_prompts
):
Expand Down
9 changes: 5 additions & 4 deletions sentry_sdk/integrations/django/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,10 +411,11 @@ def _set_transaction_name_and_source(scope, transaction_style, request):
pass


@ensure_integration_enabled(DjangoIntegration)
def _before_get_response(request):
# type: (WSGIRequest) -> None
integration = sentry_sdk.get_client().get_integration(DjangoIntegration)
if integration is None:
return

_patch_drf()

Expand All @@ -440,11 +441,10 @@ def _attempt_resolve_again(request, scope, transaction_style):
_set_transaction_name_and_source(scope, transaction_style, request)


@ensure_integration_enabled(DjangoIntegration)
def _after_get_response(request):
# type: (WSGIRequest) -> None
integration = sentry_sdk.get_client().get_integration(DjangoIntegration)
if integration.transaction_style != "url":
if integration is None or integration.transaction_style != "url":
return

scope = sentry_sdk.get_current_scope()
Expand Down Expand Up @@ -510,11 +510,12 @@ def wsgi_request_event_processor(event, hint):
return wsgi_request_event_processor


@ensure_integration_enabled(DjangoIntegration)
def _got_request_exception(request=None, **kwargs):
# type: (WSGIRequest, **Any) -> None
client = sentry_sdk.get_client()
integration = client.get_integration(DjangoIntegration)
if integration is None:
return

if request is not None and integration.transaction_style == "url":
scope = sentry_sdk.get_current_scope()
Expand Down
4 changes: 2 additions & 2 deletions sentry_sdk/integrations/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,10 @@ def _sentry_call(*args, **kwargs):

async def _sentry_app(*args, **kwargs):
# type: (*Any, **Any) -> Any
if sentry_sdk.get_client().get_integration(FastApiIntegration) is None:
integration = sentry_sdk.get_client().get_integration(FastApiIntegration)
if integration is None:
return await old_app(*args, **kwargs)

integration = sentry_sdk.get_client().get_integration(FastApiIntegration)
request = args[0]

_set_transaction_name_and_source(
Expand Down
4 changes: 3 additions & 1 deletion sentry_sdk/integrations/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,12 @@ def _set_transaction_name_and_source(scope, transaction_style, request):
pass


@ensure_integration_enabled(FlaskIntegration)
def _request_started(app, **kwargs):
# type: (Flask, **Any) -> None
integration = sentry_sdk.get_client().get_integration(FlaskIntegration)
if integration is None:
return

request = flask_request._get_current_object()

# Set the transaction name and source here,
Expand Down
6 changes: 4 additions & 2 deletions sentry_sdk/integrations/gcp.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import sys
from copy import deepcopy
from datetime import datetime, timedelta, timezone
Expand All @@ -13,7 +14,6 @@
from sentry_sdk.utils import (
AnnotatedValue,
capture_internal_exceptions,
ensure_integration_enabled,
event_from_exception,
logger,
TimeoutThread,
Expand All @@ -39,12 +39,14 @@

def _wrap_func(func):
# type: (F) -> F
@ensure_integration_enabled(GcpIntegration, func)
@functools.wraps(func)
def sentry_func(functionhandler, gcp_event, *args, **kwargs):
# type: (Any, Any, *Any, **Any) -> Any
client = sentry_sdk.get_client()

integration = client.get_integration(GcpIntegration)
if integration is None:
return func(functionhandler, gcp_event, *args, **kwargs)

configured_time = environ.get("FUNCTION_TIMEOUT_SEC")
if not configured_time:
Expand Down
8 changes: 4 additions & 4 deletions sentry_sdk/integrations/huggingface_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from sentry_sdk.utils import (
capture_internal_exceptions,
event_from_exception,
ensure_integration_enabled,
)

try:
Expand Down Expand Up @@ -55,9 +54,12 @@ def _capture_exception(exc):
def _wrap_text_generation(f):
# type: (Callable[..., Any]) -> Callable[..., Any]
@wraps(f)
@ensure_integration_enabled(HuggingfaceHubIntegration, f)
def new_text_generation(*args, **kwargs):
# type: (*Any, **Any) -> Any
integration = sentry_sdk.get_client().get_integration(HuggingfaceHubIntegration)
if integration is None:
return f(*args, **kwargs)

if "prompt" in kwargs:
prompt = kwargs["prompt"]
elif len(args) >= 2:
Expand All @@ -84,8 +86,6 @@ def new_text_generation(*args, **kwargs):
span.__exit__(None, None, None)
raise e from None

integration = sentry_sdk.get_client().get_integration(HuggingfaceHubIntegration)

with capture_internal_exceptions():
if should_send_default_pii() and integration.include_prompts:
set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, prompt)
Expand Down
Loading

0 comments on commit b3678cb

Please sign in to comment.