Skip to content

Commit

Permalink
Simplify by eliminating AsyncGenerator base and generator function. R…
Browse files Browse the repository at this point in the history
…emove any new places enforcing max_results.
  • Loading branch information
moodyjon committed May 20, 2022
1 parent 530f9c7 commit e5e9873
Showing 1 changed file with 30 additions and 56 deletions.
86 changes: 30 additions & 56 deletions lbry/dht/protocol/iterative_find.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
from itertools import chain
from collections import defaultdict, OrderedDict
from collections.abc import AsyncGenerator
from collections.abc import AsyncIterator
import typing
import logging
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -72,7 +72,7 @@ def get_shortlist(routing_table: 'TreeRoutingTable', key: bytes,
return shortlist or routing_table.find_close_peers(key)


class IterativeFinder(AsyncGenerator):
class IterativeFinder(AsyncIterator):
def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager',
routing_table: 'TreeRoutingTable', protocol: 'KademliaProtocol', key: bytes,
max_results: typing.Optional[int] = constants.K,
Expand All @@ -99,8 +99,6 @@ def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager',
self.iteration_count = 0
self.running = False
self.tasks: typing.List[asyncio.Task] = []
self.generator = None

for peer in get_shortlist(routing_table, key, shortlist):
if peer.node_id:
self._add_active(peer, force=True)
Expand Down Expand Up @@ -154,7 +152,7 @@ async def _handle_probe_result(self, peer: 'KademliaPeer', response: FindRespons
log.warning("misbehaving peer %s:%i returned peer with reserved ip %s:%i", peer.address,
peer.udp_port, address, udp_port)
self.check_result_ready(response)
self._log_state()
self._log_state(reason="check result")

def _reset_closest(self, peer):
if peer in self.active:
Expand All @@ -169,7 +167,7 @@ async def _send_probe(self, peer: 'KademliaPeer'):
except asyncio.CancelledError:
log.debug("%s[%x] cancelled probe",
type(self).__name__, id(self))
return
raise
except ValueError as err:
log.warning(str(err))
self._reset_closest(peer)
Expand Down Expand Up @@ -199,8 +197,6 @@ def _search_round(self):
break
if index > (constants.K + len(self.running_probes)):
break
if self.iteration_count + self.iteration_queue.qsize() >= self.max_results:
break
origin_address = (peer.address, peer.udp_port)
if origin_address in self.exclude:
continue
Expand Down Expand Up @@ -232,76 +228,54 @@ def callback(_):
t.add_done_callback(callback)
self.running_probes[peer] = t

def _log_state(self):
log.debug("%s[%x] [%s] check result: %i active nodes %i contacted %i produced %i queued",
def _log_state(self, reason="?"):
log.debug("%s[%x] [%s] %s: %i active nodes %i contacted %i produced %i queued",
type(self).__name__, id(self), self.key.hex()[:8],
len(self.active), len(self.contacted),
reason, len(self.active), len(self.contacted),
self.iteration_count, self.iteration_queue.qsize())

async def _generator_func(self):
try:
while self.iteration_count < self.max_results:
if self.iteration_count == 0:
result = self.get_initial_result() or await self.iteration_queue.get()
else:
result = await self.iteration_queue.get()
if not result:
# no more results
await self._aclose(reason="no more results")
self.generator = None
return
self.iteration_count += 1
yield result
# reached max_results limit
await self._aclose(reason="max_results reached")
self.generator = None
return
except asyncio.CancelledError:
await self._aclose(reason="cancelled")
self.generator = None
raise
except GeneratorExit:
await self._aclose(reason="generator exit")
self.generator = None
raise

def __aiter__(self):
if self.running:
raise Exception("already running")
self.running = True
self.generator = self._generator_func()
self.loop.call_soon(self._search_round)
return super().__aiter__()
return self

async def __anext__(self) -> typing.List['KademliaPeer']:
return await super().__anext__()

async def asend(self, value):
return await self.generator.asend(value)

async def athrow(self, typ, val=None, tb=None):
return await self.generator.athrow(typ, val, tb)
try:
if self.iteration_count == 0:
result = self.get_initial_result() or await self.iteration_queue.get()
else:
result = await self.iteration_queue.get()
if not result:
raise StopAsyncIteration
self.iteration_count += 1
return result
except asyncio.CancelledError:
await self._aclose(reason="cancelled")
raise
except StopAsyncIteration:
await self._aclose(reason="no more results")
raise

async def _aclose(self, reason="?"):
self.running = False
running_tasks = list(chain(self.tasks, self.running_probes.values()))
for task in running_tasks:
task.cancel()
log.debug("%s[%x] [%s] async close because %s: %i active nodes %i contacted %i produced %i queued",
log.debug("%s[%x] [%s] shutdown because %s: %i active nodes %i contacted %i produced %i queued",
type(self).__name__, id(self), self.key.hex()[:8],
reason, len(self.active), len(self.contacted),
self.iteration_count, self.iteration_queue.qsize())
self.running = False
self.iteration_queue.put_nowait(None)
for task in chain(self.tasks, self.running_probes.values()):
task.cancel()
self.tasks.clear()
self.running_probes.clear()

async def aclose(self):
if self.generator:
await super().aclose()
self.generator = None
if self.running:
await self._aclose(reason="aclose")
log.debug("%s[%x] [%s] async close completed",
type(self).__name__, id(self), self.key.hex()[:8])


class IterativeNodeFinder(IterativeFinder):
def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager',
routing_table: 'TreeRoutingTable', protocol: 'KademliaProtocol', key: bytes,
Expand Down

0 comments on commit e5e9873

Please sign in to comment.