Skip to content

Commit

Permalink
type hint
Browse files Browse the repository at this point in the history
  • Loading branch information
reyang committed Jul 25, 2019
1 parent 533fd2a commit 123887e
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 17 deletions.
4 changes: 3 additions & 1 deletion opentelemetry-api/src/opentelemetry/context/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import typing

from .base_context import BaseRuntimeContext

__all__ = ['Context']

Context: BaseRuntimeContext = None
Context: typing.Union[BaseRuntimeContext, None] = None

try:
from .async_context import AsyncRuntimeContext
Expand Down
13 changes: 7 additions & 6 deletions opentelemetry-api/src/opentelemetry/context/async_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextvars
from contextvars import ContextVar
import typing

from .base_context import BaseRuntimeContext


class AsyncRuntimeContext(BaseRuntimeContext):
class Slot(BaseRuntimeContext.Slot):
def __init__(self, name, default):
def __init__(self, name: str, default: typing.Any):
# pylint: disable=super-init-not-called
self.name = name
self.contextvar = contextvars.ContextVar(name)
self.contextvar: typing.Any = ContextVar(name)
self.default = default if callable(default) else (lambda: default)

def clear(self):
def clear(self) -> None:
self.contextvar.set(self.default())

def get(self):
def get(self) -> typing.Any:
try:
return self.contextvar.get()
except LookupError:
value = self.default()
self.set(value)
return value

def set(self, value):
def set(self, value: typing.Any) -> None:
self.contextvar.set(value)
19 changes: 13 additions & 6 deletions opentelemetry-api/src/opentelemetry/context/base_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def register_slot(cls, name: str, default: typing.Any = None) -> 'Slot':
cls._slots[name] = slot
return slot

def apply(self, snapshot) -> None:
def apply(self, snapshot: typing.Dict[str, typing.Any]) -> None:
"""Set the current context from a given snapshot dictionary"""

for name in snapshot:
Expand All @@ -75,24 +75,31 @@ def snapshot(self) -> typing.Dict[str, typing.Any]:
def __repr__(self) -> str:
return '{}({})'.format(type(self).__name__, self.snapshot())

def __getattr__(self, name) -> typing.Any:
def __getattr__(self, name: str) -> typing.Any:
if name not in self._slots:
self.register_slot(name, None)
slot = self._slots[name]
return slot.get()

def __setattr__(self, name, value) -> None:
def __setattr__(self, name: str, value: typing.Any) -> None:
if name not in self._slots:
self.register_slot(name, None)
slot = self._slots[name]
slot.set(value)

def with_current_context(self, func: typing.Callable) -> typing.Callable:
"""Capture the current context and apply it to the provided func"""
def with_current_context(
self,
func: typing.Callable[..., typing.Any],
) -> typing.Callable[..., typing.Any]:
"""Capture the current context and apply it to the provided func.
"""

caller_context = self.snapshot()

def call_with_current_context(*args, **kwargs) -> typing.Any:
def call_with_current_context(
*args: typing.Any,
**kwargs: typing.Any,
) -> typing.Any:
try:
backup_context = self.snapshot()
self.apply(caller_context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import threading
import typing

from .base_context import BaseRuntimeContext

Expand All @@ -21,21 +22,21 @@ class ThreadLocalRuntimeContext(BaseRuntimeContext):
class Slot(BaseRuntimeContext.Slot):
_thread_local = threading.local()

def __init__(self, name, default):
def __init__(self, name: str, default: typing.Any):
# pylint: disable=super-init-not-called
self.name = name
self.default = default if callable(default) else (lambda: default)

def clear(self):
def clear(self) -> None:
setattr(self._thread_local, self.name, self.default())

def get(self):
def get(self) -> typing.Any:
try:
return getattr(self._thread_local, self.name)
except AttributeError:
value = self.default()
self.set(value)
return value

def set(self, value):
def set(self, value: typing.Any) -> None:
setattr(self._thread_local, self.name, value)

0 comments on commit 123887e

Please sign in to comment.