Skip to content

Commit

Permalink
Merge pull request #24 from mic1on/feat_base_worker
Browse files Browse the repository at this point in the history
✨ feat: Support assignment worker class
  • Loading branch information
mic1on authored Nov 23, 2023
2 parents ced1b58 + 9078332 commit 7587c69
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 73 deletions.
18 changes: 18 additions & 0 deletions example/example_use_threadpool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from onestep import step, CronBroker
from onestep.worker import ThreadPoolWorker

cron_broker = CronBroker("* * * * * */3", body="hi cron")


@step(from_broker=cron_broker,
workers=3,
worker_class=ThreadPoolWorker)
def cron_task(message):
print(message)
return message


if __name__ == '__main__':
step.set_debugging()
step.start(block=True)
# step.shutdown()
23 changes: 18 additions & 5 deletions src/onestep/onestep.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,27 @@
import functools
import collections
import inspect
import logging

from inspect import isgenerator, iscoroutinefunction, isasyncgenfunction, isasyncgen
from typing import Optional, List, Dict, Any, Callable, Union
from typing import Optional, List, Dict, Any, Callable, Union, Type

from .broker.base import BaseBroker
from .exception import StopMiddleware
from .message import Message
from .retry import TimesRetry
from .signal import message_sent, started, stopped
from .state import State
from .worker import WorkerThread
from .worker import ThreadWorker, BaseWorker

logger = logging.getLogger(__name__)

DEFAULT_WORKERS = 1
DEFAULT_WORKER_CLASS = ThreadWorker


class BaseOneStep:
consumers: Dict[str, List[WorkerThread]] = collections.defaultdict(list)
consumers: Dict[str, List[BaseWorker]] = collections.defaultdict(list)
state = State() # 全局状态

def __init__(self, fn,
Expand All @@ -28,13 +30,15 @@ def __init__(self, fn,
from_broker: Union[BaseBroker, List[BaseBroker], None] = None,
to_broker: Union[BaseBroker, List[BaseBroker], None] = None,
workers: Optional[int] = None,
worker_class: Optional[Type[BaseWorker]] = None,
middlewares: Optional[List[Any]] = None,
retry: Union[Callable, object] = TimesRetry(),
error_callback: Optional[Union[Callable, object]] = None):
self.group = group
self.fn = fn
self.name = name or fn.__name__
self.workers = workers or DEFAULT_WORKERS
self.worker_class = worker_class or DEFAULT_WORKER_CLASS
self.middlewares = middlewares or []

self.from_brokers = self._init_broker(from_broker)
Expand All @@ -59,10 +63,16 @@ def _init_broker(broker: Union[BaseBroker, List[BaseBroker], None] = None):

def _add_consumer(self, broker):
""" 添加来源消费者 """
for _ in range(self.workers):
worker_class_params = inspect.signature(self.worker_class.__init__).parameters
if "workers" in worker_class_params:
self.consumers[self.group].append(
WorkerThread(onestep=self, broker=broker)
self.worker_class(onestep=self, broker=broker, workers=self.workers)
)
else:
for _ in range(self.workers):
self.consumers[self.group].append(
self.worker_class(onestep=self, broker=broker)
)

@classmethod
def _find_consumers(cls, group: Optional[str] = None):
Expand Down Expand Up @@ -182,11 +192,13 @@ async def __call__(self, *args, **kwargs):
class step:

def __init__(self,
*,
group: str = "OneStep",
name: str = None,
from_broker: Union[BaseBroker, List[BaseBroker], None] = None,
to_broker: Union[BaseBroker, List[BaseBroker], None] = None,
workers: Optional[int] = None,
worker_class: Optional[Type[BaseWorker]] = None,
middlewares: Optional[List[Any]] = None,
retry: Union[Callable, object] = TimesRetry(),
error_callback: Optional[Union[Callable, object]] = None):
Expand All @@ -196,6 +208,7 @@ def __init__(self,
"from_broker": from_broker,
"to_broker": to_broker,
"workers": workers,
"worker_class": worker_class,
"middlewares": middlewares,
"retry": retry,
"error_callback": error_callback
Expand Down
190 changes: 124 additions & 66 deletions src/onestep/worker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
将指定的函数放入线程中运行
"""
from concurrent.futures import ThreadPoolExecutor
from typing import Dict

try:
Expand All @@ -22,16 +23,11 @@
logger = logging.getLogger(__name__)


class WorkerThread(threading.Thread):
class BaseWorker:
broker_exit: Dict[BaseBroker, bool] = {}
broker_exit_lock = threading.Lock()

def __init__(self, onestep, broker: BaseBroker, *args, **kwargs):
"""
线程执行包装过的`onestep`函数
:param onestep: OneStep实例
:param broker: 监听的from broker
"""
super().__init__(daemon=True)
self.instance = onestep
self.retry = self.instance.retry
self.error_callback = self.instance.error_callback
Expand All @@ -40,6 +36,92 @@ def __init__(self, onestep, broker: BaseBroker, *args, **kwargs):
self.kwargs = kwargs
self._shutdown = False

def start(self):
"""启动 Worker"""
raise NotImplementedError

def run(self):
"""执行 Worker 的逻辑"""
raise NotImplementedError

def shutdown(self):
"""关闭 Worker"""
raise NotImplementedError

def _receive_message(self):
for result in self.broker.consume():
if self._shutdown:
break
if result is None:
continue
messages = (
result
if isinstance(result, Iterable)
else [result]
)
for message in messages:
message.broker = message.broker or self.broker
logger.debug(f"{self.instance.name} receive message<{message}> from {self.broker!r}")
message_received.send(self, message=message)
try:
self.instance.before_emit("consume", message=message)
self._run_instance(message)
self.instance.after_emit("consume", message=message)
except DropMessage as e:
message_drop.send(self, message=message, reason=e)
logger.warning(f"{self.instance.name} dropped <{type(e).__name__}: {str(e)}>")
message.reject()
finally:
# When message is triggered by cancel_consume, it will be shutdown
if self.broker.cancel_consume and self.broker.cancel_consume(message):
self.shutdown()
else:
if self.broker.once:
self.shutdown()

def _run_instance(self, message):
"""执行实例的逻辑"""
try:
if iscoroutinefunction(self.instance.fn) or isasyncgenfunction(self.instance.fn):
async_to_sync(self.instance)(message, *self.args, **self.kwargs)
else:
self.instance(message, *self.args, **self.kwargs)
message_consumed.send(self, message=message)
message.confirm()
except Exception as e:
message_error.send(self, message=message, error=e)
if self.instance.state.debug:
logger.exception(f"{self.instance.name} run error <{type(e).__name__}: {str(e)}>")
else:
logger.error(f"{self.instance.name} run error <{type(e).__name__}: {str(e)}>")
message.set_exception()

retry_status = self.retry(message)
if retry_status is RetryStatus.END_WITH_CALLBACK:
if self.error_callback:
self.error_callback(message)
message.reject()
elif retry_status is RetryStatus.END_IGNORE_CALLBACK:
# 由于是队列内重试,不会触发错误回调
message.requeue()


class ThreadWorker(BaseWorker):

def __init__(self, onestep, broker: BaseBroker, *args, **kwargs):
"""
线程执行包装过的`onestep`函数
:param onestep: OneStep实例
:param broker: 监听的from broker
"""
super().__init__(onestep, broker, *args, **kwargs)
self.thread = None

def start(self):
"""启动单线程 Worker"""
self.thread = threading.Thread(target=self.run, daemon=True)
self.thread.start()

def run(self):
"""线程执行包装过的`onestep`函数
Expand All @@ -48,68 +130,44 @@ def run(self):
"""

while not self._shutdown:
if WorkerThread.broker_exit.get(self.broker, False):
self.shutdown()
break
for result in self.broker.consume():
if self._shutdown:
with ThreadWorker.broker_exit_lock:
if ThreadWorker.broker_exit.get(self.broker, False):
self.shutdown()
break
if result is None:
continue
messages = (
result
if isinstance(result, Iterable)
else [result]
)
for message in messages:
message.broker = message.broker or self.broker
logger.debug(f"{self.instance.name} receive message<{message}> from {self.broker!r}")
message_received.send(self, message=message)
try:
self.instance.before_emit("consume", message=message)
self._run_instance(message)
self.instance.after_emit("consume", message=message)
except DropMessage as e:
message_drop.send(self, message=message, reason=e)
logger.warning(f"{self.instance.name} dropped <{type(e).__name__}: {str(e)}>")
message.reject()
finally:
# When message is triggered by cancel_consume, it will be shutdown
if self.broker.cancel_consume and self.broker.cancel_consume(message):
self.shutdown()
else:
if self.broker.once:
self.shutdown()
self._receive_message()

def shutdown(self):
WorkerThread.broker_exit[self.broker] = True
ThreadWorker.broker_exit[self.broker] = True
self.broker.shutdown()
self._shutdown = True

def _run_instance(self, message):
while True:
try:
if iscoroutinefunction(self.instance.fn) or isasyncgenfunction(self.instance.fn):
async_to_sync(self.instance)(message, *self.args, **self.kwargs)
else:
self.instance(message, *self.args, **self.kwargs)
message_consumed.send(self, message=message)
return message.confirm()
except Exception as e:
message_error.send(self, message=message, error=e)
if self.instance.state.debug:
logger.exception(f"{self.instance.name} run error <{type(e).__name__}: {str(e)}>")
else:
logger.error(f"{self.instance.name} run error <{type(e).__name__}: {str(e)}>")
message.set_exception()

retry_status = self.retry(message)
if retry_status is RetryStatus.CONTINUE:
continue
elif retry_status is RetryStatus.END_WITH_CALLBACK:
if self.error_callback:
self.error_callback(message)
return message.reject()
else: # RetryStatus.END_IGNORE_CALLBACK
# 由于是队列内重试,不会触发错误回调
return message.requeue()

class ThreadPoolWorker(BaseWorker):

def __init__(self, onestep, broker: BaseBroker, workers=None, *args, **kwargs):
super().__init__(onestep, broker, *args, **kwargs)
self.executor = ThreadPoolExecutor(max_workers=workers)

def start(self):
"""启动线程池 Worker"""
self.executor.submit(self.run)

def run(self):
"""线程执行包装过的`onestep`函数
`fn`为`onestep`函数,执行会调用`onestep`的`__call__`方法
:return:
"""

while not self._shutdown:
with ThreadPoolWorker.broker_exit_lock:
if ThreadPoolWorker.broker_exit.get(self.broker, False):
self.shutdown()
break
self._receive_message()

def shutdown(self):
"""关闭线程池 Worker"""
ThreadPoolWorker.broker_exit[self.broker] = True
self._shutdown = True
self.executor.shutdown()
6 changes: 4 additions & 2 deletions tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from onestep import MemoryBroker
from onestep.onestep import BaseOneStep
from onestep.worker import WorkerThread
from onestep.worker import ThreadWorker


@pytest.fixture
Expand All @@ -13,7 +13,9 @@ def broker():
@pytest.fixture
def worker_thread(broker):
onestep = BaseOneStep(fn=lambda message: message)
return WorkerThread(onestep, broker)
wt = ThreadWorker(onestep, broker)
wt.start()
yield wt


def test_shutdown(worker_thread):
Expand Down

0 comments on commit 7587c69

Please sign in to comment.