Skip to content

Commit

Permalink
more type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
totaam committed Jun 13, 2023
1 parent 99f3440 commit e783ebf
Show file tree
Hide file tree
Showing 15 changed files with 83 additions and 79 deletions.
4 changes: 2 additions & 2 deletions xpra/net/bytestreams.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import os
import errno
import socket
from typing import Dict, Any, Optional, Union
from typing import Dict, Any, Optional, Union, Callable

from xpra.net.common import ConnectionClosedException, IP_SOCKTYPES, TCP_SOCKTYPES
from xpra.util import envint, envbool, hasenv, csv
Expand Down Expand Up @@ -77,7 +77,7 @@ def can_retry(e) -> Union[bool,str]:
raise ConnectionClosedException(e) from None
return False

def untilConcludes(is_active_cb, can_retry_cb, f, *a, **kw):
def untilConcludes(is_active_cb:Callable, can_retry_cb:Callable, f:Callable, *a, **kw):
while is_active_cb():
try:
return f(*a, **kw)
Expand Down
46 changes: 23 additions & 23 deletions xpra/net/file_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import uuid
from time import monotonic
from dataclasses import dataclass
from typing import Dict, Any, Optional, Callable, Set
from typing import Dict, Any, Optional, Callable, Set, Tuple

from xpra.child_reaper import getChildReaper
from xpra.os_util import bytestostr, strtobytes, umask_context, POSIX, WIN32
Expand Down Expand Up @@ -40,15 +40,15 @@
ACCEPT = 1 #the file / URL will be sent
OPEN = 2 #don't send, open on sender

def osclose(fd):
def osclose(fd:int) -> None:
try:
os.close(fd)
except OSError as e:
filelog("os.close(%s)", fd, exc_info=True)
filelog.error("Error closing file download:")
filelog.estr(e)

def basename(filename):
def basename(filename:str) -> str:
#we can't use os.path.basename,
#because the remote end may have sent us a filename
#which is using a different pathsep
Expand All @@ -66,7 +66,7 @@ def basename(filename):
tmp += char
return tmp

def safe_open_download_file(basefilename, mimetype):
def safe_open_download_file(basefilename:str, mimetype:str):
from xpra.platform.paths import get_download_dir # pylint: disable=import-outside-toplevel
dd = os.path.expanduser(get_download_dir())
filename = os.path.abspath(os.path.join(dd, basename(basefilename)))
Expand Down Expand Up @@ -124,13 +124,13 @@ class FileTransferAttributes:
def __init__(self):
self.init_attributes()

def init_opts(self, opts, can_ask=True):
def init_opts(self, opts, can_ask=True) -> None:
#get the settings from a config object
self.init_attributes(opts.file_transfer, opts.file_size_limit,
opts.printing, opts.open_files, opts.open_url, opts.open_command, can_ask)

def init_attributes(self, file_transfer="no", file_size_limit=10, printing="no",
open_files="no", open_url="no", open_command=None, can_ask=True):
open_files="no", open_url="no", open_command=None, can_ask=True) -> None:
filelog("file transfer: init_attributes%s",
(file_transfer, file_size_limit, printing, open_files, open_url, open_command, can_ask))
def pbool(name, v):
Expand Down Expand Up @@ -206,7 +206,7 @@ class FileTransferHandler(FileTransferAttributes):
used by both clients and server to share the common code and attributes
"""

def init_attributes(self, *args):
def init_attributes(self, *args) -> None:
super().init_attributes(*args)
self.remote_file_transfer = False
self.remote_file_transfer_ask = False
Expand All @@ -230,7 +230,7 @@ def init_attributes(self, *args):
self.idle_add = GLib.idle_add
self.source_remove = GLib.source_remove

def cleanup(self):
def cleanup(self) -> None:
for t in self.pending_send_data_timers.values():
self.source_remove(t)
self.pending_send_data_timers = {}
Expand All @@ -247,7 +247,7 @@ def cleanup(self):
self.init_attributes()


def parse_file_transfer_caps(self, c):
def parse_file_transfer_caps(self, c) -> None:
fc = c.dictget("file")
if fc:
fc = typedict(fc)
Expand Down Expand Up @@ -278,7 +278,7 @@ def parse_file_transfer_caps(self, c):
self.remote_file_chunks = max(0, min(self.remote_file_size_limit, c.intget("file-chunks")))
self.dump_remote_caps()

def dump_remote_caps(self):
def dump_remote_caps(self) -> None:
filelog("file transfer remote caps: file-transfer=%-5s (ask=%s)",
self.remote_file_transfer, self.remote_file_transfer_ask)
filelog("file transfer remote caps: printing=%-5s (ask=%s)",
Expand Down Expand Up @@ -306,7 +306,7 @@ def get_info(self) -> Dict[str,Any]:
return info


def digest_mismatch(self, filename:str, digest, expected_digest):
def digest_mismatch(self, filename:str, digest, expected_digest) -> None:
filelog.error(f"Error: data does not match, invalid {digest.name} file digest")
filelog.error(f" for {filename!r}")
filelog.error(f" received {digest.hexdigest()}")
Expand All @@ -318,7 +318,7 @@ def digest_mismatch(self, filename:str, digest, expected_digest):
filelog.error(f"Error: failed to delete uploaded file {filename}")


def _check_chunk_receiving(self, chunk_id:int, chunk_no:int):
def _check_chunk_receiving(self, chunk_id:int, chunk_no:int) -> None:
chunk_state = self.receive_chunks_in_progress.get(chunk_id)
filelog("_check_chunk_receiving(%s, %s) chunk_state=%s", chunk_id, chunk_no, chunk_state)
if not chunk_state:
Expand All @@ -332,15 +332,15 @@ def _check_chunk_receiving(self, chunk_id:int, chunk_no:int):
filelog.error(f"Error: chunked file transfer f{chunk_id} timed out")
self.receive_chunks_in_progress.pop(chunk_id, None)

def cancel_download(self, send_id:str, message="Cancelled"):
def cancel_download(self, send_id:str, message="Cancelled") -> None:
filelog("cancel_download(%s, %s)", send_id, message)
for chunk_id, chunk_state in dict(self.receive_chunks_in_progress).items():
if chunk_state.send_id==send_id:
self.cancel_file(chunk_id, message)
return
filelog.error("Error: cannot cancel download %s, entry not found!", u(send_id))

def cancel_file(self, chunk_id:int, message:str, chunk:int=0):
def cancel_file(self, chunk_id:int, message:str, chunk:int=0) -> None:
filelog("cancel_file%s", (chunk_id, message, chunk))
chunk_state = self.receive_chunks_in_progress.get(chunk_id)
if chunk_state:
Expand All @@ -366,7 +366,7 @@ def clean_receive_state():
filelog.error(f" {filename!r} : {e}")
self.send("ack-file-chunk", chunk_id, False, message, chunk)

def _process_send_file_chunk(self, packet):
def _process_send_file_chunk(self, packet) -> None:
chunk_id, chunk, file_data, has_more = packet[1:5]
chunk_id = net_utf8(chunk_id)
#if len(file_data)<1024:
Expand Down Expand Up @@ -452,7 +452,7 @@ def progress(position, error=None):
self.process_downloaded_file(filename, chunk_state.mimetype,
chunk_state.printit, chunk_state.openit, chunk_state.filesize, options)

def accept_data(self, send_id:str, dtype, basefilename:str, printit:bool, openit:bool):
def accept_data(self, send_id:str, dtype, basefilename:str, printit:bool, openit:bool) -> Tuple[bool,bool]:
#subclasses should check the flags,
#and if ask is True, verify they have accepted this specific send_id
filelog("accept_data%s", (send_id, dtype, basefilename, printit, openit))
Expand All @@ -477,7 +477,7 @@ def accept_data(self, send_id:str, dtype, basefilename:str, printit:bool, openit
openit = False
return (printit, openit)

def _process_send_file(self, packet):
def _process_send_file(self, packet) -> None:
#the remote end is sending us a file
start = monotonic()
basefilename, mimetype, printit, openit, filesize, file_data, options = packet[1:8]
Expand Down Expand Up @@ -569,7 +569,7 @@ def _process_send_file(self, packet):
self.process_downloaded_file(filename, mimetype, printit, openit, filesize, options)


def process_downloaded_file(self, filename, mimetype, printit, openit, filesize, options):
def process_downloaded_file(self, filename:str, mimetype:str, printit:bool, openit:bool, filesize:int, options) -> None:
filelog.info("downloaded %s bytes to %s file%s:",
filesize, (mimetype or "temporary"), ["", " for printing"][int(printit)])
filelog.info(" '%s'", filename)
Expand Down Expand Up @@ -665,11 +665,11 @@ def get_open_env(self) -> Dict[str,str]:
env["XPRA_XDG_OPEN"] = "1"
return env

def _open_file(self, url:str):
def _open_file(self, url:str) -> None:
filelog("_open_file(%s)", url)
self.exec_open_command(url)

def _open_url(self, url:str):
def _open_url(self, url:str) -> None:
filelog("_open_url(%s)", url)
if POSIX:
#we can't use webbrowser,
Expand All @@ -680,7 +680,7 @@ def _open_url(self, url:str):
import webbrowser #pylint: disable=import-outside-toplevel
webbrowser.open_new_tab(url)

def exec_open_command(self, url:str):
def exec_open_command(self, url:str) -> None:
filelog("exec_open_command(%s)", url)
try:
import shlex #pylint: disable=import-outside-toplevel
Expand All @@ -705,12 +705,12 @@ def open_done(*_args):
cr = getChildReaper()
cr.add_process(proc, f"Open file {url}", command, True, True, open_done)

def file_size_warning(self, action:str, location:str, basefilename:str, filesize:int, limit:int):
def file_size_warning(self, action:str, location:str, basefilename:str, filesize:int, limit:int) -> None:
filelog.warn("Warning: cannot %s the file '%s'", action, basefilename)
filelog.warn(" this file is too large: %sB", std_unit(filesize))
filelog.warn(" the %s file size limit is %sB", location, std_unit(limit))

def check_file_size(self, action:str, filename:str, filesize:int):
def check_file_size(self, action:str, filename:str, filesize:int) -> bool:
basefilename = os.path.basename(filename)
if filesize>self.file_size_limit:
self.file_size_warning(action, "local", basefilename, filesize, self.file_size_limit)
Expand Down
7 changes: 3 additions & 4 deletions xpra/net/libproxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class ProxyResolutionError(RuntimeError):
def __init__(self):
self._pf = _libproxy.px_proxy_factory_new()

def getProxies(self, url):
def getProxies(self, url:str):
"""Given a URL, returns a list of proxies in priority order to be used
to reach that URL.
Expand Down Expand Up @@ -121,14 +121,13 @@ def getProxies(self, url):
i=0
while array[i]:
proxy_bytes = cast(array[i], c_char_p).value
proxies.append(proxy_bytes.decode('utf-8', errors='replace'))
if proxy_bytes:
proxies.append(proxy_bytes.decode('utf-8', errors='replace'))
i += 1

_libproxy.px_proxy_factory_free_proxies(array)

return proxies

def __del__(self):
if _libproxy:
_libproxy.px_proxy_factory_free(self._pf)

13 changes: 7 additions & 6 deletions xpra/net/mdns/avahi_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,11 @@ def __init__(self, service_type=XPRA_TCP_MDNS_TYPE, mdns_found=None, mdns_add=No
#self.mdns_update = mdns_update
self.server = None

def resolve_error(self, *args):
def resolve_error(self, *args) -> None:
log.error("AvahiListener.resolve_error%s", args)

def service_resolved(self, interface, protocol, name, stype, domain, host, x, address, port, text_array, v):
def service_resolved(self, interface, protocol, name:str, stype:str,
domain:str, host:str, x, address, port:int, text_array, v) -> None:
log("AvahiListener.service_resolved%s",
(interface, protocol, name, stype, domain, host, x, address, port, "..", v))
if self.mdns_add:
Expand All @@ -61,7 +62,7 @@ def service_resolved(self, interface, protocol, name, stype, domain, host, x, ad
nargs = (dbus_to_native(x) for x in (interface, protocol, name, stype, domain, host, address, port, text))
self.mdns_add(*nargs)

def service_found(self, interface, protocol, name, stype, domain, flags):
def service_found(self, interface, protocol, name:str, stype:str, domain:str, flags:int) -> None:
log("service_found%s", (interface, protocol, name, stype, domain, flags))
if flags & avahi.LOOKUP_RESULT_LOCAL:
# local service, skip
Expand All @@ -72,14 +73,14 @@ def service_found(self, interface, protocol, name, stype, domain, flags):
domain, avahi.PROTO_UNSPEC, dbus.UInt32(0),
reply_handler=self.service_resolved, error_handler=self.resolve_error)

def service_removed(self, interface, protocol, name, stype, domain, flags):
def service_removed(self, interface, protocol, name:str, stype:str, domain, flags:int) -> None:
log("service_removed%s", (interface, protocol, name, stype, domain, flags))
if self.mdns_remove:
nargs = (dbus_to_native(x) for x in (interface, protocol, name, stype, domain, flags))
self.mdns_remove(*nargs)


def start(self):
def start(self) -> None:
self.server = dbus.Interface(self.bus.get_object(avahi.DBUS_NAME, '/'), 'org.freedesktop.Avahi.Server')
log("AvahiListener.start() server=%s", self.server)

Expand All @@ -93,7 +94,7 @@ def start(self):
s = self.sbrowser.connect_to_signal("ItemRemove", self.service_removed)
self.signal_match.append(s)

def stop(self):
def stop(self) -> None:
sm = self.signal_match
self.signal_match = []
for s in sm:
Expand Down
24 changes: 12 additions & 12 deletions xpra/net/mdns/avahi_publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
log = Logger("network", "mdns")


def get_interface_index(host):
def get_interface_index(host) -> int:
log("get_interface_index(%s)", host)
if host in ("0.0.0.0", "", "*", "::"):
return avahi.IF_UNSPEC
Expand Down Expand Up @@ -52,7 +52,7 @@ class AvahiPublishers:
and to convert the text dict into a TXT string.
"""

def __init__(self, listen_on, service_name, service_type=XPRA_TCP_MDNS_TYPE, text_dict=None):
def __init__(self, listen_on, service_name:str, service_type:str=XPRA_TCP_MDNS_TYPE, text_dict=None):
log("AvahiPublishers%s", (listen_on, service_name, service_type, text_dict))
self.publishers = []
try:
Expand Down Expand Up @@ -93,7 +93,7 @@ def __init__(self, listen_on, service_name, service_type=XPRA_TCP_MDNS_TYPE, tex
service_type, domain="", host=fqdn,
text=txt, interface=iface_index))

def start(self):
def start(self) -> None:
log("avahi:starting: %s", self.publishers)
if not self.publishers:
return
Expand All @@ -105,12 +105,12 @@ def start(self):
log.warn(" to avoid this warning, disable mdns support ")
log.warn(" using the 'mdns=no' option")

def stop(self):
def stop(self) -> None:
log("stopping: %s", self.publishers)
for publisher in self.publishers:
publisher.stop()

def update_txt(self, txt):
def update_txt(self, txt) -> None:
for publisher in self.publishers:
publisher.update_txt(txt)

Expand All @@ -132,7 +132,7 @@ def __init__(self, bus, name, port, stype=XPRA_TCP_MDNS_TYPE, domain="", host=""
self.server = None
self.group = None

def iface(self):
def iface(self) -> str:
if self.interface>0:
return "interface %i" % self.interface
return "all interfaces"
Expand All @@ -143,7 +143,7 @@ def host_str(self) -> str:
def __repr__(self):
return "AvahiPublisher(%s)" % self.host_str()

def start(self):
def start(self) -> bool:
try:
self.server = dbus.Interface(self.bus.get_object(avahi.DBUS_NAME, avahi.DBUS_PATH_SERVER),
avahi.DBUS_INTERFACE_SERVER)
Expand All @@ -157,7 +157,7 @@ def start(self):
self.server.connect_to_signal("StateChanged", self.server_state_changed)
return self.server_state_changed(self.server.GetState())

def server_state_changed(self, state, error=None):
def server_state_changed(self, state, error=None) -> bool:
log("server_state_changed(%s, %s) on %s", state, error, self.server)
if state == avahi.SERVER_COLLISION:
log.error("Error: mdns server name collision")
Expand All @@ -179,7 +179,7 @@ def server_state_changed(self, state, error=None):
log.warn(" for name '%s' and port %i on %s", self.name, self.port, self.iface())
return False

def add_service(self):
def add_service(self) -> None:
if not self.group:
return
try:
Expand Down Expand Up @@ -207,7 +207,7 @@ def add_service(self):
log.error(" %s", x)
self.stop()

def stop(self):
def stop(self) -> None:
group = self.group
log("%s.stop() group=%s", self, group)
if group:
Expand All @@ -220,7 +220,7 @@ def stop(self):
self.server = None


def update_txt(self, txt):
def update_txt(self, txt) -> None:
if not self.server:
log("update_txt(%s) ignored, already stopped", txt)
return
Expand Down Expand Up @@ -257,7 +257,7 @@ def main():
name = "test service"
bus = init_system_bus()
publishers = []
def add(service_type=XPRA_TCP_MDNS_TYPE):
def add(service_type:str=XPRA_TCP_MDNS_TYPE):
publisher = AvahiPublisher(bus, name, port, stype=service_type, host=host, text=("somename=somevalue",))
publishers.append(publisher)
def start():
Expand Down
Loading

0 comments on commit e783ebf

Please sign in to comment.