Skip to content

Commit

Permalink
feat: add unpack variants option (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
mdegat01 authored Sep 20, 2022
1 parent 1209048 commit cfad28b
Show file tree
Hide file tree
Showing 8 changed files with 244 additions and 25 deletions.
29 changes: 22 additions & 7 deletions src/dbus_fast/aio/proxy_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ..message_bus import BaseMessageBus
from ..proxy_object import BaseProxyInterface, BaseProxyObject
from ..signature import Variant
from ..signature import unpack_variants as unpack


class ProxyInterface(BaseProxyInterface):
Expand Down Expand Up @@ -74,7 +75,9 @@ class ProxyInterface(BaseProxyInterface):
"""

def _add_method(self, intr_method):
async def method_fn(*args, flags=MessageFlag.NONE):
async def method_fn(
*args, flags=MessageFlag.NONE, unpack_variants: bool = False
):
input_body, unix_fds = replace_fds_with_idx(
intr_method.in_signature, list(args)
)
Expand Down Expand Up @@ -103,16 +106,24 @@ async def method_fn(*args, flags=MessageFlag.NONE):

if not out_len:
return None
elif out_len == 1:

if unpack_variants:
body = unpack(body)

if out_len == 1:
return body[0]
else:
return body
return body

method_name = f"call_{BaseProxyInterface._to_snake_case(intr_method.name)}"
setattr(self, method_name, method_fn)

def _add_property(self, intr_property):
async def property_getter():
def _add_property(
self,
intr_property,
):
async def property_getter(
*, flags=MessageFlag.NONE, unpack_variants: bool = False
):
msg = await self.bus.call(
Message(
destination=self.bus_name,
Expand All @@ -133,7 +144,11 @@ async def property_getter():
msg,
)

return replace_idx_with_fds("v", msg.body, msg.unix_fds)[0].value
body = replace_idx_with_fds("v", msg.body, msg.unix_fds)[0].value

if unpack_variants:
return unpack(body)
return body

async def property_setter(val):
variant = Variant(intr_property.signature, val)
Expand Down
31 changes: 21 additions & 10 deletions src/dbus_fast/glib/proxy_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ..message_bus import BaseMessageBus
from ..proxy_object import BaseProxyInterface, BaseProxyObject
from ..signature import Variant
from ..signature import unpack_variants as unpack

# glib is optional
try:
Expand Down Expand Up @@ -113,7 +114,7 @@ def _add_method(self, intr_method):
in_len = len(intr_method.in_args)
out_len = len(intr_method.out_args)

def method_fn(*args):
def method_fn(*args, unpack_variants: bool = False):
if len(args) != in_len + 1:
raise TypeError(
f"method {intr_method.name} expects {in_len} arguments and a callback (got {len(args)} args)"
Expand All @@ -136,7 +137,10 @@ def call_notify(msg, err):
except DBusError as e:
err = e

callback(msg.body, err)
if unpack_variants:
callback(unpack(msg.body), err)
else:
callback(msg.body, err)

self.bus.call(
Message(
Expand All @@ -150,7 +154,7 @@ def call_notify(msg, err):
call_notify,
)

def method_fn_sync(*args):
def method_fn_sync(*args, unpack_variants: bool = False):
main = GLib.MainLoop()
call_error = None
call_body = None
Expand All @@ -171,10 +175,13 @@ def callback(body, err):

if not out_len:
return None
elif out_len == 1:

if unpack_variants:
call_body = unpack(call_body)

if out_len == 1:
return call_body[0]
else:
return call_body
return call_body

method_name = f"call_{BaseProxyInterface._to_snake_case(intr_method.name)}"
method_name_sync = f"{method_name}_sync"
Expand All @@ -183,7 +190,7 @@ def callback(body, err):
setattr(self, method_name_sync, method_fn_sync)

def _add_property(self, intr_property):
def property_getter(callback):
def property_getter(callback, *, unpack_variants: bool = False):
def call_notify(msg, err):
if err:
callback(None, err)
Expand All @@ -204,8 +211,10 @@ def call_notify(msg, err):
)
callback(None, err)
return

callback(variant.value, None)
if unpack_variants:
callback(unpack(variant.value), None)
else:
callback(variant.value, None)

self.bus.call(
Message(
Expand All @@ -219,7 +228,7 @@ def call_notify(msg, err):
call_notify,
)

def property_getter_sync():
def property_getter_sync(*, unpack_variants: bool = False):
property_value = None
reply_error = None

Expand All @@ -236,6 +245,8 @@ def callback(value, err):
main.run()
if reply_error:
raise reply_error
if unpack_variants:
return unpack(property_value)
return property_value

def property_setter(value, callback):
Expand Down
36 changes: 29 additions & 7 deletions src/dbus_fast/proxy_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,27 @@
import logging
import re
import xml.etree.ElementTree as ET
from typing import Coroutine, List, Type, Union
from dataclasses import dataclass
from typing import Callable, Coroutine, Dict, List, Type, Union

from . import introspection as intr
from . import message_bus
from ._private.util import replace_idx_with_fds
from .constants import ErrorType, MessageType
from .errors import DBusError, InterfaceNotFoundError
from .message import Message
from .signature import unpack_variants as unpack
from .validators import assert_bus_name_valid, assert_object_path_valid


@dataclass
class SignalHandler:
"""Signal handler."""

fn: Callable
unpack_variants: bool


class BaseProxyInterface:
"""An abstract class representing a proxy to an interface exported on the bus by another client.
Expand Down Expand Up @@ -46,7 +56,7 @@ def __init__(self, bus_name, path, introspection, bus):
self.path = path
self.introspection = introspection
self.bus = bus
self._signal_handlers = {}
self._signal_handlers: Dict[str, List[SignalHandler]] = {}
self._signal_match_rule = f"type='signal',sender={bus_name},interface={introspection.name},path={path}"

_underscorer1 = re.compile(r"(.)([A-Z][a-z]+)")
Expand Down Expand Up @@ -110,13 +120,21 @@ def _message_handler(self, msg):
return

body = replace_idx_with_fds(msg.signature, msg.body, msg.unix_fds)
no_sig = None
for handler in self._signal_handlers[msg.member]:
cb_result = handler(*body)
if handler.unpack_variants:
if not no_sig:
no_sig = unpack(body)
data = no_sig
else:
data = body

cb_result = handler.fn(*data)
if isinstance(cb_result, Coroutine):
asyncio.create_task(cb_result)

def _add_signal(self, intr_signal, interface):
def on_signal_fn(fn):
def on_signal_fn(fn, *, unpack_variants: bool = False):
fn_signature = inspect.signature(fn)
if len(fn_signature.parameters) != len(intr_signal.args) and (
inspect.Parameter.VAR_POSITIONAL
Expand All @@ -134,11 +152,15 @@ def on_signal_fn(fn):
if intr_signal.name not in self._signal_handlers:
self._signal_handlers[intr_signal.name] = []

self._signal_handlers[intr_signal.name].append(fn)
self._signal_handlers[intr_signal.name].append(
SignalHandler(fn, unpack_variants)
)

def off_signal_fn(fn):
def off_signal_fn(fn, *, unpack_variants: bool = False):
try:
i = self._signal_handlers[intr_signal.name].index(fn)
i = self._signal_handlers[intr_signal.name].index(
SignalHandler(fn, unpack_variants)
)
del self._signal_handlers[intr_signal.name][i]
if not self._signal_handlers[intr_signal.name]:
del self._signal_handlers[intr_signal.name]
Expand Down
11 changes: 11 additions & 0 deletions src/dbus_fast/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,17 @@
from .validators import is_object_path_valid


def unpack_variants(data: Any):
"""Unpack variants and remove signature info."""
if isinstance(data, Variant):
return unpack_variants(data.value)
if isinstance(data, dict):
return {k: unpack_variants(v) for k, v in data.items()}
if isinstance(data, list):
return [unpack_variants(item) for item in data]
return data


class SignatureType:
"""A class that represents a single complete type within a signature.
Expand Down
18 changes: 18 additions & 0 deletions tests/client/test_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dbus_fast import DBusError, aio, glib
from dbus_fast.message import MessageFlag
from dbus_fast.service import ServiceInterface, method
from dbus_fast.signature import Variant
from tests.util import check_gi_repository, skip_reason_no_gi

has_gi = check_gi_repository()
Expand Down Expand Up @@ -33,6 +34,11 @@ def ConcatStrings(self, what1: "s", what2: "s") -> "s":
def EchoThree(self, what1: "s", what2: "s", what3: "s") -> "sss":
return [what1, what2, what3]

@method()
def GetComplex(self) -> "a{sv}":
"""Return complex output."""
return {"hello": Variant("s", "world")}

@method()
def ThrowsError(self):
raise DBusError("test.error", "something went wrong")
Expand Down Expand Up @@ -81,6 +87,12 @@ async def test_aio_proxy_object():
)
assert result is None

result = await interface.call_get_complex()
assert result == {"hello": Variant("s", "world")}

result = await interface.call_get_complex(unpack_variants=True)
assert result == {"hello": "world"}

with pytest.raises(DBusError):
try:
await interface.call_throws_error()
Expand Down Expand Up @@ -120,6 +132,12 @@ def test_glib_proxy_object():
result = interface.call_echo_three_sync("hello", "there", "world")
assert result == ["hello", "there", "world"]

result = interface.call_get_complex_sync()
assert result == {"hello": Variant("s", "world")}

result = interface.call_get_complex_sync(unpack_variants=True)
assert result == {"hello": "world"}

with pytest.raises(DBusError):
try:
result = interface.call_throws_error_sync()
Expand Down
18 changes: 18 additions & 0 deletions tests/client/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from dbus_fast import DBusError, Message, aio, glib
from dbus_fast.service import PropertyAccess, ServiceInterface, dbus_property
from dbus_fast.signature import Variant
from tests.util import check_gi_repository, skip_reason_no_gi

has_gi = check_gi_repository()
Expand All @@ -27,6 +28,11 @@ def SomeProperty(self, val: "s"):
def Int64Property(self) -> "x":
return self._int64_property

@dbus_property(access=PropertyAccess.READ)
def ComplexProperty(self) -> "a{sv}":
"""Return complex output."""
return {"hello": Variant("s", "world")}

@dbus_property()
def ErrorThrowingProperty(self) -> "s":
raise DBusError(self.error_name, self.error_text)
Expand Down Expand Up @@ -59,6 +65,12 @@ async def test_aio_properties():
await interface.set_some_property("different")
assert service_interface._some_property == "different"

prop = await interface.get_complex_property()
assert prop == {"hello": Variant("s", "world")}

prop = await interface.get_complex_property(unpack_variants=True)
assert prop == {"hello": "world"}

with pytest.raises(DBusError):
try:
prop = await interface.get_error_throwing_property()
Expand Down Expand Up @@ -102,6 +114,12 @@ def test_glib_properties():
interface.set_some_property_sync("different")
assert service_interface._some_property == "different"

prop = interface.get_complex_property_sync()
assert prop == {"hello": Variant("s", "world")}

prop = interface.get_complex_property_sync(unpack_variants=True)
assert prop == {"hello": "world"}

with pytest.raises(DBusError):
try:
prop = interface.get_error_throwing_property_sync()
Expand Down
Loading

0 comments on commit cfad28b

Please sign in to comment.