Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add unpack variants option #20

Merged
merged 8 commits into from
Sep 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions src/dbus_fast/aio/proxy_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ..message_bus import BaseMessageBus
from ..proxy_object import BaseProxyInterface, BaseProxyObject
from ..signature import Variant
from ..signature import unpack_variants as unpack


class ProxyInterface(BaseProxyInterface):
Expand Down Expand Up @@ -74,7 +75,9 @@ class ProxyInterface(BaseProxyInterface):
"""

def _add_method(self, intr_method):
async def method_fn(*args, flags=MessageFlag.NONE):
async def method_fn(
*args, flags=MessageFlag.NONE, unpack_variants: bool = False
):
input_body, unix_fds = replace_fds_with_idx(
intr_method.in_signature, list(args)
)
Expand Down Expand Up @@ -103,16 +106,24 @@ async def method_fn(*args, flags=MessageFlag.NONE):

if not out_len:
return None
elif out_len == 1:

if unpack_variants:
body = unpack(body)

if out_len == 1:
return body[0]
else:
return body
return body

method_name = f"call_{BaseProxyInterface._to_snake_case(intr_method.name)}"
setattr(self, method_name, method_fn)

def _add_property(self, intr_property):
async def property_getter():
def _add_property(
self,
intr_property,
):
async def property_getter(
*, flags=MessageFlag.NONE, unpack_variants: bool = False
):
msg = await self.bus.call(
bdraco marked this conversation as resolved.
Show resolved Hide resolved
Message(
destination=self.bus_name,
Expand All @@ -133,7 +144,11 @@ async def property_getter():
msg,
)

return replace_idx_with_fds("v", msg.body, msg.unix_fds)[0].value
body = replace_idx_with_fds("v", msg.body, msg.unix_fds)[0].value

if unpack_variants:
return unpack(body)
return body

async def property_setter(val):
variant = Variant(intr_property.signature, val)
Expand Down
31 changes: 21 additions & 10 deletions src/dbus_fast/glib/proxy_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ..message_bus import BaseMessageBus
from ..proxy_object import BaseProxyInterface, BaseProxyObject
from ..signature import Variant
from ..signature import unpack_variants as unpack

# glib is optional
try:
Expand Down Expand Up @@ -113,7 +114,7 @@ def _add_method(self, intr_method):
in_len = len(intr_method.in_args)
out_len = len(intr_method.out_args)

def method_fn(*args):
def method_fn(*args, unpack_variants: bool = False):
if len(args) != in_len + 1:
raise TypeError(
f"method {intr_method.name} expects {in_len} arguments and a callback (got {len(args)} args)"
Expand All @@ -136,7 +137,10 @@ def call_notify(msg, err):
except DBusError as e:
err = e

callback(msg.body, err)
if unpack_variants:
callback(unpack(msg.body), err)
else:
callback(msg.body, err)

self.bus.call(
Message(
Expand All @@ -150,7 +154,7 @@ def call_notify(msg, err):
call_notify,
)

def method_fn_sync(*args):
def method_fn_sync(*args, unpack_variants: bool = False):
main = GLib.MainLoop()
call_error = None
call_body = None
Expand All @@ -171,10 +175,13 @@ def callback(body, err):

if not out_len:
return None
elif out_len == 1:

if unpack_variants:
call_body = unpack(call_body)

if out_len == 1:
return call_body[0]
else:
return call_body
return call_body

method_name = f"call_{BaseProxyInterface._to_snake_case(intr_method.name)}"
method_name_sync = f"{method_name}_sync"
Expand All @@ -183,7 +190,7 @@ def callback(body, err):
setattr(self, method_name_sync, method_fn_sync)

def _add_property(self, intr_property):
def property_getter(callback):
def property_getter(callback, *, unpack_variants: bool = False):
def call_notify(msg, err):
if err:
callback(None, err)
Expand All @@ -204,8 +211,10 @@ def call_notify(msg, err):
)
callback(None, err)
return

callback(variant.value, None)
if unpack_variants:
callback(unpack(variant.value), None)
else:
callback(variant.value, None)

self.bus.call(
Message(
Expand All @@ -219,7 +228,7 @@ def call_notify(msg, err):
call_notify,
)

def property_getter_sync():
def property_getter_sync(*, unpack_variants: bool = False):
property_value = None
reply_error = None

Expand All @@ -236,6 +245,8 @@ def callback(value, err):
main.run()
if reply_error:
raise reply_error
if unpack_variants:
return unpack(property_value)
return property_value

def property_setter(value, callback):
Expand Down
36 changes: 29 additions & 7 deletions src/dbus_fast/proxy_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,27 @@
import logging
import re
import xml.etree.ElementTree as ET
from typing import Coroutine, List, Type, Union
from dataclasses import dataclass
from typing import Callable, Coroutine, Dict, List, Type, Union

from . import introspection as intr
from . import message_bus
from ._private.util import replace_idx_with_fds
from .constants import ErrorType, MessageType
from .errors import DBusError, InterfaceNotFoundError
from .message import Message
from .signature import unpack_variants as unpack
from .validators import assert_bus_name_valid, assert_object_path_valid


@dataclass
class SignalHandler:
"""Signal handler."""

fn: Callable
unpack_variants: bool


class BaseProxyInterface:
"""An abstract class representing a proxy to an interface exported on the bus by another client.

Expand Down Expand Up @@ -46,7 +56,7 @@ def __init__(self, bus_name, path, introspection, bus):
self.path = path
self.introspection = introspection
self.bus = bus
self._signal_handlers = {}
self._signal_handlers: Dict[str, List[SignalHandler]] = {}
self._signal_match_rule = f"type='signal',sender={bus_name},interface={introspection.name},path={path}"

_underscorer1 = re.compile(r"(.)([A-Z][a-z]+)")
Expand Down Expand Up @@ -110,13 +120,21 @@ def _message_handler(self, msg):
return

body = replace_idx_with_fds(msg.signature, msg.body, msg.unix_fds)
no_sig = None
for handler in self._signal_handlers[msg.member]:
cb_result = handler(*body)
if handler.unpack_variants:
if not no_sig:
no_sig = unpack(body)
data = no_sig
else:
data = body

cb_result = handler.fn(*data)
if isinstance(cb_result, Coroutine):
asyncio.create_task(cb_result)

def _add_signal(self, intr_signal, interface):
def on_signal_fn(fn):
def on_signal_fn(fn, *, unpack_variants: bool = False):
fn_signature = inspect.signature(fn)
if len(fn_signature.parameters) != len(intr_signal.args) and (
inspect.Parameter.VAR_POSITIONAL
Expand All @@ -134,11 +152,15 @@ def on_signal_fn(fn):
if intr_signal.name not in self._signal_handlers:
self._signal_handlers[intr_signal.name] = []

self._signal_handlers[intr_signal.name].append(fn)
self._signal_handlers[intr_signal.name].append(
SignalHandler(fn, unpack_variants)
)

def off_signal_fn(fn):
def off_signal_fn(fn, *, unpack_variants: bool = False):
try:
i = self._signal_handlers[intr_signal.name].index(fn)
i = self._signal_handlers[intr_signal.name].index(
SignalHandler(fn, unpack_variants)
)
del self._signal_handlers[intr_signal.name][i]
if not self._signal_handlers[intr_signal.name]:
del self._signal_handlers[intr_signal.name]
Expand Down
11 changes: 11 additions & 0 deletions src/dbus_fast/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,17 @@
from .validators import is_object_path_valid


def unpack_variants(data: Any):
"""Unpack variants and remove signature info."""
if isinstance(data, Variant):
return unpack_variants(data.value)
if isinstance(data, dict):
return {k: unpack_variants(v) for k, v in data.items()}
if isinstance(data, list):
return [unpack_variants(item) for item in data]
return data


class SignatureType:
"""A class that represents a single complete type within a signature.

Expand Down
18 changes: 18 additions & 0 deletions tests/client/test_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dbus_fast import DBusError, aio, glib
from dbus_fast.message import MessageFlag
from dbus_fast.service import ServiceInterface, method
from dbus_fast.signature import Variant
from tests.util import check_gi_repository, skip_reason_no_gi

has_gi = check_gi_repository()
Expand Down Expand Up @@ -33,6 +34,11 @@ def ConcatStrings(self, what1: "s", what2: "s") -> "s":
def EchoThree(self, what1: "s", what2: "s", what3: "s") -> "sss":
return [what1, what2, what3]

@method()
def GetComplex(self) -> "a{sv}":
"""Return complex output."""
return {"hello": Variant("s", "world")}

@method()
def ThrowsError(self):
raise DBusError("test.error", "something went wrong")
Expand Down Expand Up @@ -81,6 +87,12 @@ async def test_aio_proxy_object():
)
assert result is None

result = await interface.call_get_complex()
assert result == {"hello": Variant("s", "world")}

result = await interface.call_get_complex(unpack_variants=True)
assert result == {"hello": "world"}

with pytest.raises(DBusError):
try:
await interface.call_throws_error()
Expand Down Expand Up @@ -120,6 +132,12 @@ def test_glib_proxy_object():
result = interface.call_echo_three_sync("hello", "there", "world")
assert result == ["hello", "there", "world"]

result = interface.call_get_complex_sync()
assert result == {"hello": Variant("s", "world")}

result = interface.call_get_complex_sync(unpack_variants=True)
assert result == {"hello": "world"}

with pytest.raises(DBusError):
try:
result = interface.call_throws_error_sync()
Expand Down
18 changes: 18 additions & 0 deletions tests/client/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from dbus_fast import DBusError, Message, aio, glib
from dbus_fast.service import PropertyAccess, ServiceInterface, dbus_property
from dbus_fast.signature import Variant
from tests.util import check_gi_repository, skip_reason_no_gi

has_gi = check_gi_repository()
Expand All @@ -27,6 +28,11 @@ def SomeProperty(self, val: "s"):
def Int64Property(self) -> "x":
return self._int64_property

@dbus_property(access=PropertyAccess.READ)
def ComplexProperty(self) -> "a{sv}":
"""Return complex output."""
return {"hello": Variant("s", "world")}

@dbus_property()
def ErrorThrowingProperty(self) -> "s":
raise DBusError(self.error_name, self.error_text)
Expand Down Expand Up @@ -59,6 +65,12 @@ async def test_aio_properties():
await interface.set_some_property("different")
assert service_interface._some_property == "different"

prop = await interface.get_complex_property()
assert prop == {"hello": Variant("s", "world")}

prop = await interface.get_complex_property(unpack_variants=True)
assert prop == {"hello": "world"}

with pytest.raises(DBusError):
try:
prop = await interface.get_error_throwing_property()
Expand Down Expand Up @@ -102,6 +114,12 @@ def test_glib_properties():
interface.set_some_property_sync("different")
assert service_interface._some_property == "different"

prop = interface.get_complex_property_sync()
assert prop == {"hello": Variant("s", "world")}

prop = interface.get_complex_property_sync(unpack_variants=True)
assert prop == {"hello": "world"}

with pytest.raises(DBusError):
try:
prop = interface.get_error_throwing_property_sync()
Expand Down
Loading