diff --git a/__init__.py b/__init__.py index 0d837c8..bd22178 100644 --- a/__init__.py +++ b/__init__.py @@ -8,7 +8,7 @@ from .views import satspay_generic_router from .views_api import satspay_api_router from .views_api_themes import satspay_theme_router -from .websocket_handler import websocket_handler +from .websocket_handler import restart_websocket_task, websocket_task satspay_ext: APIRouter = APIRouter(prefix="/satspay", tags=["satspay"]) satspay_ext.include_router(satspay_generic_router) @@ -31,6 +31,8 @@ def satspay_stop(): task.cancel() except Exception as ex: logger.warning(ex) + if websocket_task: + websocket_task.cancel() def satspay_start(): @@ -39,11 +41,9 @@ def satspay_start(): paid_invoices_task = create_permanent_unique_task( "ext_satspay_paid_invoices", wait_for_paid_invoices ) - websocket_task = create_permanent_unique_task( - "ext_satspay_websocket", websocket_handler - ) onchain_task = create_permanent_unique_task("ext_satspay_onchain", wait_for_onchain) - scheduled_tasks.extend([paid_invoices_task, websocket_task, onchain_task]) + scheduled_tasks.extend([paid_invoices_task, onchain_task]) + restart_websocket_task() __all__ = ["db", "satspay_ext", "satspay_static_files", "satspay_start", "satspay_stop"] diff --git a/views_api.py b/views_api.py index eebee66..fed0851 100644 --- a/views_api.py +++ b/views_api.py @@ -28,6 +28,7 @@ ) from .models import Charge, CreateCharge, SatspaySettings from .tasks import start_onchain_listener, stop_onchain_listener +from .websocket_handler import restart_websocket_task satspay_api_router = APIRouter() @@ -168,9 +169,12 @@ async def api_get_or_create_settings() -> SatspaySettings: @satspay_api_router.put("/api/v1/settings", dependencies=[Depends(check_admin)]) async def api_update_settings(data: SatspaySettings) -> SatspaySettings: - return await update_satspay_settings(data) + settings = await update_satspay_settings(data) + restart_websocket_task() + return settings @satspay_api_router.delete("/api/v1/settings", dependencies=[Depends(check_admin)]) async def api_delete_settings() -> None: await delete_satspay_settings() + restart_websocket_task() diff --git a/websocket_handler.py b/websocket_handler.py index 2d921be..ccf2735 100644 --- a/websocket_handler.py +++ b/websocket_handler.py @@ -1,7 +1,9 @@ import asyncio import json +from typing import Optional from lnbits.settings import settings +from lnbits.tasks import create_permanent_unique_task from loguru import logger from websockets.client import connect @@ -10,6 +12,18 @@ ws_receive_queue: asyncio.Queue[dict] = asyncio.Queue() ws_send_queue: asyncio.Queue[dict] = asyncio.Queue() +websocket_task: Optional[asyncio.Task] = None + + +def restart_websocket_task(): + logger.info("Restarting websocket task...") + global websocket_task + if websocket_task: + websocket_task.cancel() + websocket_task = create_permanent_unique_task( + "ext_satspay_websocket", websocket_handler + ) + async def consumer_handler(websocket): async for message in websocket: