From 953c93e43dd9cda3b8d1458ccf350c5ac6de3dc2 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Sun, 4 Apr 2021 02:14:29 +0200 Subject: [PATCH] Reuse SSL context in host browser (#485) --- pychromecast/dial.py | 36 ++++++++++++++++++++++++++++-------- pychromecast/discovery.py | 8 +++++--- 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/pychromecast/dial.py b/pychromecast/dial.py index f508fc82c..f781e558c 100644 --- a/pychromecast/dial.py +++ b/pychromecast/dial.py @@ -58,7 +58,7 @@ def _get_host_from_zc_service_info(service_info: zeroconf.ServiceInfo): return (host, port) -def _get_status(host, services, zconf, path, secure, timeout): +def _get_status(host, services, zconf, path, secure, timeout, context): """ :param host: Hostname or ip to fetch status from :type host: str @@ -75,21 +75,29 @@ def _get_status(host, services, zconf, path, secure, timeout): headers = {"content-type": "application/json"} - context = None if secure: url = FORMAT_BASE_URL_HTTPS.format(host) + path - context = ssl.SSLContext() - context.verify_mode = ssl.CERT_NONE else: url = FORMAT_BASE_URL_HTTP.format(host) + path + has_context = bool(context) + if secure and not has_context: + context = get_ssl_context() + req = urllib.request.Request(url, headers=headers) with urllib.request.urlopen(req, timeout=timeout, context=context) as response: data = response.read() return json.loads(data.decode("utf-8")) -def get_device_status(host, services=None, zconf=None, timeout=10): +def get_ssl_context(): + """Create an SSL context.""" + context = ssl.SSLContext() + context.verify_mode = ssl.CERT_NONE + return context + + +def get_device_status(host, services=None, zconf=None, timeout=10, context=None): """ :param host: Hostname or ip to fetch status from :type host: str @@ -99,7 +107,13 @@ def get_device_status(host, services=None, zconf=None, timeout=10): try: status = _get_status( - host, services, zconf, "/setup/eureka_info?options=detail", True, timeout + host, + services, + zconf, + "/setup/eureka_info?options=detail", + True, + timeout, + context, ) friendly_name = status.get("name", "Unknown Chromecast") @@ -144,7 +158,7 @@ def _get_group_info(host, group): return MultizoneInfo(name, uuid, leader_host, leader_port) -def get_multizone_status(host, services=None, zconf=None, timeout=10): +def get_multizone_status(host, services=None, zconf=None, timeout=10, context=None): """ :param host: Hostname or ip to fetch status from :type host: str @@ -154,7 +168,13 @@ def get_multizone_status(host, services=None, zconf=None, timeout=10): try: status = _get_status( - host, services, zconf, "/setup/eureka_info?params=multizone", True, timeout + host, + services, + zconf, + "/setup/eureka_info?params=multizone", + True, + timeout, + context, ) dynamic_groups = [] diff --git a/pychromecast/discovery.py b/pychromecast/discovery.py index 13e40e5f0..920267fcd 100644 --- a/pychromecast/discovery.py +++ b/pychromecast/discovery.py @@ -11,7 +11,7 @@ import zeroconf from .const import SERVICE_TYPE_HOST, SERVICE_TYPE_MDNS -from .dial import get_device_status, get_multizone_status +from .dial import get_device_status, get_multizone_status, get_ssl_context DISCOVER_TIMEOUT = 5 @@ -206,6 +206,7 @@ def __init__(self, cast_listener, devices, lock): self._next_update = time.time() self._services_lock = lock self._start_requested = False + self._context = None self.stop = threading.Event() def add_hosts(self, known_hosts): @@ -234,6 +235,7 @@ def update_hosts(self, known_hosts): def run(self): """Start worker thread.""" _LOGGER.debug("HostBrowser thread started") + self._context = get_ssl_context() try: while not self.stop.is_set(): self._poll_hosts() @@ -252,7 +254,7 @@ def _poll_hosts(self): uuids = [] if self.stop.is_set(): break - device_status = get_device_status(host, timeout=4) + device_status = get_device_status(host, timeout=4, context=self._context) try: hoststatus = self._known_hosts[host] except KeyError: @@ -280,7 +282,7 @@ def _poll_hosts(self): ) uuids.append(device_status.uuid) - multizone_status = get_multizone_status(host) + multizone_status = get_multizone_status(host, context=self._context) if not multizone_status: return