From 123887e96babe0bb2e16d1272b0996342efa154e Mon Sep 17 00:00:00 2001 From: Reiley Yang Date: Thu, 18 Jul 2019 00:12:28 -0700 Subject: [PATCH] type hint --- .../src/opentelemetry/context/__init__.py | 4 +++- .../opentelemetry/context/async_context.py | 13 +++++++------ .../src/opentelemetry/context/base_context.py | 19 +++++++++++++------ .../context/thread_local_context.py | 9 +++++---- 4 files changed, 28 insertions(+), 17 deletions(-) diff --git a/opentelemetry-api/src/opentelemetry/context/__init__.py b/opentelemetry-api/src/opentelemetry/context/__init__.py index e9b8678fb06..c2fe840cb44 100644 --- a/opentelemetry-api/src/opentelemetry/context/__init__.py +++ b/opentelemetry-api/src/opentelemetry/context/__init__.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import typing + from .base_context import BaseRuntimeContext __all__ = ['Context'] -Context: BaseRuntimeContext = None +Context: typing.Union[BaseRuntimeContext, None] = None try: from .async_context import AsyncRuntimeContext diff --git a/opentelemetry-api/src/opentelemetry/context/async_context.py b/opentelemetry-api/src/opentelemetry/context/async_context.py index 614cfaacc0d..6e5f0c0fce3 100644 --- a/opentelemetry-api/src/opentelemetry/context/async_context.py +++ b/opentelemetry-api/src/opentelemetry/context/async_context.py @@ -12,23 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextvars +from contextvars import ContextVar +import typing from .base_context import BaseRuntimeContext class AsyncRuntimeContext(BaseRuntimeContext): class Slot(BaseRuntimeContext.Slot): - def __init__(self, name, default): + def __init__(self, name: str, default: typing.Any): # pylint: disable=super-init-not-called self.name = name - self.contextvar = contextvars.ContextVar(name) + self.contextvar: typing.Any = ContextVar(name) self.default = default if callable(default) else (lambda: default) - def clear(self): + def clear(self) -> None: self.contextvar.set(self.default()) - def get(self): + def get(self) -> typing.Any: try: return self.contextvar.get() except LookupError: @@ -36,5 +37,5 @@ def get(self): self.set(value) return value - def set(self, value): + def set(self, value: typing.Any) -> None: self.contextvar.set(value) diff --git a/opentelemetry-api/src/opentelemetry/context/base_context.py b/opentelemetry-api/src/opentelemetry/context/base_context.py index 67d42a10b49..5cc1794cd06 100644 --- a/opentelemetry-api/src/opentelemetry/context/base_context.py +++ b/opentelemetry-api/src/opentelemetry/context/base_context.py @@ -60,7 +60,7 @@ def register_slot(cls, name: str, default: typing.Any = None) -> 'Slot': cls._slots[name] = slot return slot - def apply(self, snapshot) -> None: + def apply(self, snapshot: typing.Dict[str, typing.Any]) -> None: """Set the current context from a given snapshot dictionary""" for name in snapshot: @@ -75,24 +75,31 @@ def snapshot(self) -> typing.Dict[str, typing.Any]: def __repr__(self) -> str: return '{}({})'.format(type(self).__name__, self.snapshot()) - def __getattr__(self, name) -> typing.Any: + def __getattr__(self, name: str) -> typing.Any: if name not in self._slots: self.register_slot(name, None) slot = self._slots[name] return slot.get() - def __setattr__(self, name, value) -> None: + def __setattr__(self, name: str, value: typing.Any) -> None: if name not in self._slots: self.register_slot(name, None) slot = self._slots[name] slot.set(value) - def with_current_context(self, func: typing.Callable) -> typing.Callable: - """Capture the current context and apply it to the provided func""" + def with_current_context( + self, + func: typing.Callable[..., typing.Any], + ) -> typing.Callable[..., typing.Any]: + """Capture the current context and apply it to the provided func. + """ caller_context = self.snapshot() - def call_with_current_context(*args, **kwargs) -> typing.Any: + def call_with_current_context( + *args: typing.Any, + **kwargs: typing.Any, + ) -> typing.Any: try: backup_context = self.snapshot() self.apply(caller_context) diff --git a/opentelemetry-api/src/opentelemetry/context/thread_local_context.py b/opentelemetry-api/src/opentelemetry/context/thread_local_context.py index 9830aa83b03..25f07c98069 100644 --- a/opentelemetry-api/src/opentelemetry/context/thread_local_context.py +++ b/opentelemetry-api/src/opentelemetry/context/thread_local_context.py @@ -13,6 +13,7 @@ # limitations under the License. import threading +import typing from .base_context import BaseRuntimeContext @@ -21,15 +22,15 @@ class ThreadLocalRuntimeContext(BaseRuntimeContext): class Slot(BaseRuntimeContext.Slot): _thread_local = threading.local() - def __init__(self, name, default): + def __init__(self, name: str, default: typing.Any): # pylint: disable=super-init-not-called self.name = name self.default = default if callable(default) else (lambda: default) - def clear(self): + def clear(self) -> None: setattr(self._thread_local, self.name, self.default()) - def get(self): + def get(self) -> typing.Any: try: return getattr(self._thread_local, self.name) except AttributeError: @@ -37,5 +38,5 @@ def get(self): self.set(value) return value - def set(self, value): + def set(self, value: typing.Any) -> None: setattr(self._thread_local, self.name, value)