From 1f93a36b5786112cce8ac5c3c0b7b4d92ac8a6c1 Mon Sep 17 00:00:00 2001 From: Chris Kleinknecht Date: Fri, 19 Jul 2019 10:10:42 -0700 Subject: [PATCH] Type hint edits for #57 (#59) --- mypy-relaxed.ini | 2 +- mypy.ini | 2 +- .../src/opentelemetry/context/__init__.py | 4 ++- .../opentelemetry/context/async_context.py | 19 +++++------ .../src/opentelemetry/context/base_context.py | 32 +++++++++++-------- .../context/thread_local_context.py | 18 ++++++----- 6 files changed, 44 insertions(+), 33 deletions(-) diff --git a/mypy-relaxed.ini b/mypy-relaxed.ini index 9cd56789a1e..205688353e3 100644 --- a/mypy-relaxed.ini +++ b/mypy-relaxed.ini @@ -3,7 +3,7 @@ [mypy] disallow_any_unimported = True ; disallow_any_expr = True -; disallow_any_decorated = True + disallow_any_decorated = True ; disallow_any_explicit = True disallow_any_generics = True disallow_subclassing_any = True diff --git a/mypy.ini b/mypy.ini index abb7aba2961..ba375b62b1d 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,7 +1,7 @@ [mypy] disallow_any_unimported = True disallow_any_expr = True -; disallow_any_decorated = True + disallow_any_decorated = True ; disallow_any_explicit = True disallow_any_generics = True disallow_subclassing_any = True diff --git a/opentelemetry-api/src/opentelemetry/context/__init__.py b/opentelemetry-api/src/opentelemetry/context/__init__.py index c2fe840cb44..db975f9e510 100644 --- a/opentelemetry-api/src/opentelemetry/context/__init__.py +++ b/opentelemetry-api/src/opentelemetry/context/__init__.py @@ -16,9 +16,11 @@ from .base_context import BaseRuntimeContext + __all__ = ['Context'] -Context: typing.Union[BaseRuntimeContext, None] = None + +Context: typing.Optional[BaseRuntimeContext] try: from .async_context import AsyncRuntimeContext diff --git a/opentelemetry-api/src/opentelemetry/context/async_context.py b/opentelemetry-api/src/opentelemetry/context/async_context.py index 6e5f0c0fce3..f118870ecea 100644 --- a/opentelemetry-api/src/opentelemetry/context/async_context.py +++ b/opentelemetry-api/src/opentelemetry/context/async_context.py @@ -12,24 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -from contextvars import ContextVar import typing +from contextvars import ContextVar -from .base_context import BaseRuntimeContext +from . import base_context -class AsyncRuntimeContext(BaseRuntimeContext): - class Slot(BaseRuntimeContext.Slot): - def __init__(self, name: str, default: typing.Any): +class AsyncRuntimeContext(base_context.BaseRuntimeContext): + class Slot(base_context.BaseRuntimeContext.Slot): + def __init__(self, name: str, default: 'object'): # pylint: disable=super-init-not-called self.name = name - self.contextvar: typing.Any = ContextVar(name) - self.default = default if callable(default) else (lambda: default) + self.contextvar: 'ContextVar[object]' = ContextVar(name) + self.default: typing.Callable[..., object] + self.default = base_context.wrap_callable(default) def clear(self) -> None: self.contextvar.set(self.default()) - def get(self) -> typing.Any: + def get(self) -> 'object': try: return self.contextvar.get() except LookupError: @@ -37,5 +38,5 @@ def get(self) -> typing.Any: self.set(value) return value - def set(self, value: typing.Any) -> None: + def set(self, value: 'object') -> None: self.contextvar.set(value) diff --git a/opentelemetry-api/src/opentelemetry/context/base_context.py b/opentelemetry-api/src/opentelemetry/context/base_context.py index 5cc1794cd06..35ee179a4b8 100644 --- a/opentelemetry-api/src/opentelemetry/context/base_context.py +++ b/opentelemetry-api/src/opentelemetry/context/base_context.py @@ -16,18 +16,24 @@ import typing +def wrap_callable(target: 'object') -> typing.Callable[[], object]: + if callable(target): + return target + return lambda: target + + class BaseRuntimeContext: class Slot: - def __init__(self, name: str, default: typing.Any): + def __init__(self, name: str, default: 'object'): raise NotImplementedError def clear(self) -> None: raise NotImplementedError - def get(self) -> typing.Any: + def get(self) -> 'object': raise NotImplementedError - def set(self, value: typing.Any) -> None: + def set(self, value: 'object') -> None: raise NotImplementedError _lock = threading.Lock() @@ -42,7 +48,7 @@ def clear(cls) -> None: slot.clear() @classmethod - def register_slot(cls, name: str, default: typing.Any = None) -> 'Slot': + def register_slot(cls, name: str, default: 'object' = None) -> 'Slot': """Register a context slot with an optional default value. :type name: str @@ -60,13 +66,13 @@ def register_slot(cls, name: str, default: typing.Any = None) -> 'Slot': cls._slots[name] = slot return slot - def apply(self, snapshot: typing.Dict[str, typing.Any]) -> None: + def apply(self, snapshot: typing.Dict[str, 'object']) -> None: """Set the current context from a given snapshot dictionary""" for name in snapshot: setattr(self, name, snapshot[name]) - def snapshot(self) -> typing.Dict[str, typing.Any]: + def snapshot(self) -> typing.Dict[str, 'object']: """Return a dictionary of current slots by reference.""" keys = self._slots.keys() @@ -75,13 +81,13 @@ def snapshot(self) -> typing.Dict[str, typing.Any]: def __repr__(self) -> str: return '{}({})'.format(type(self).__name__, self.snapshot()) - def __getattr__(self, name: str) -> typing.Any: + def __getattr__(self, name: str) -> 'object': if name not in self._slots: self.register_slot(name, None) slot = self._slots[name] return slot.get() - def __setattr__(self, name: str, value: typing.Any) -> None: + def __setattr__(self, name: str, value: 'object') -> None: if name not in self._slots: self.register_slot(name, None) slot = self._slots[name] @@ -89,17 +95,17 @@ def __setattr__(self, name: str, value: typing.Any) -> None: def with_current_context( self, - func: typing.Callable[..., typing.Any], - ) -> typing.Callable[..., typing.Any]: + func: typing.Callable[..., 'object'], + ) -> typing.Callable[..., 'object']: """Capture the current context and apply it to the provided func. """ caller_context = self.snapshot() def call_with_current_context( - *args: typing.Any, - **kwargs: typing.Any, - ) -> typing.Any: + *args: 'object', + **kwargs: 'object', + ) -> 'object': try: backup_context = self.snapshot() self.apply(caller_context) diff --git a/opentelemetry-api/src/opentelemetry/context/thread_local_context.py b/opentelemetry-api/src/opentelemetry/context/thread_local_context.py index 25f07c98069..dd11128b7ac 100644 --- a/opentelemetry-api/src/opentelemetry/context/thread_local_context.py +++ b/opentelemetry-api/src/opentelemetry/context/thread_local_context.py @@ -15,28 +15,30 @@ import threading import typing -from .base_context import BaseRuntimeContext +from . import base_context -class ThreadLocalRuntimeContext(BaseRuntimeContext): - class Slot(BaseRuntimeContext.Slot): +class ThreadLocalRuntimeContext(base_context.BaseRuntimeContext): + class Slot(base_context.BaseRuntimeContext.Slot): _thread_local = threading.local() - def __init__(self, name: str, default: typing.Any): + def __init__(self, name: str, default: 'object'): # pylint: disable=super-init-not-called self.name = name - self.default = default if callable(default) else (lambda: default) + self.default: typing.Callable[..., object] + self.default = base_context.wrap_callable(default) def clear(self) -> None: setattr(self._thread_local, self.name, self.default()) - def get(self) -> typing.Any: + def get(self) -> 'object': try: - return getattr(self._thread_local, self.name) + got: object = getattr(self._thread_local, self.name) + return got except AttributeError: value = self.default() self.set(value) return value - def set(self, value: typing.Any) -> None: + def set(self, value: 'object') -> None: setattr(self._thread_local, self.name, value)