Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
dni committed Aug 2, 2024
1 parent 3bee031 commit 3405341
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 77 deletions.
40 changes: 18 additions & 22 deletions __init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
import asyncio
from typing import List

from fastapi import APIRouter
from lnbits.db import Database
from lnbits.helpers import template_renderer
from lnbits.tasks import create_permanent_unique_task
from loguru import logger

from .nostr.client.client import NostrClient
from .router import NostrRouter

db = Database("ext_nostrclient")
from .crud import db
from .nostr_client import all_routers, nostr_client
from .tasks import check_relays, init_relays, subscribe_events
from .views import nostrclient_generic_router
from .views_api import nostrclient_api_router

nostrclient_static_files = [
{
Expand All @@ -20,23 +17,11 @@
]

nostrclient_ext: APIRouter = APIRouter(prefix="/nostrclient", tags=["nostrclient"])

nostr_client: NostrClient = NostrClient()

# we keep this in
all_routers: list[NostrRouter] = []
nostrclient_ext.include_router(nostrclient_generic_router)
nostrclient_ext.include_router(nostrclient_api_router)
scheduled_tasks: list[asyncio.Task] = []


def nostr_renderer():
return template_renderer(["nostrclient/templates"])


from .tasks import check_relays, init_relays, subscribe_events # noqa
from .views import * # noqa
from .views_api import * # noqa


async def nostrclient_stop():
for task in scheduled_tasks:
try:
Expand All @@ -55,9 +40,20 @@ async def nostrclient_stop():


def nostrclient_start():
from lnbits.tasks import create_permanent_unique_task

task1 = create_permanent_unique_task("ext_nostrclient_init_relays", init_relays)
task2 = create_permanent_unique_task(
"ext_nostrclient_subscrive_events", subscribe_events
)
task3 = create_permanent_unique_task("ext_nostrclient_check_relays", check_relays)
scheduled_tasks.extend([task1, task2, task3])


__all__ = [
"db",
"nostrclient_ext",
"nostrclient_static_files",
"nostrclient_stop",
"nostrclient_start",
]
9 changes: 6 additions & 3 deletions crud.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import json
from typing import List, Optional
from typing import Optional

from lnbits.db import Database

from . import db
from .models import Config, Relay

db = Database("ext_nostrclient")


async def get_relays() -> List[Relay]:
async def get_relays() -> list[Relay]:
rows = await db.fetchall("SELECT * FROM nostrclient.relays")
return [Relay.from_row(r) for r in rows]

Expand Down
5 changes: 5 additions & 0 deletions nostr_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .nostr.client.client import NostrClient
from .router import NostrRouter

nostr_client: NostrClient = NostrClient()
all_routers: list[NostrRouter] = []
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ build-backend = "poetry.core.masonry.api"

[tool.mypy]
exclude = "(nostr/*)"

[[tool.mypy.overrides]]
module = [
"nostr.*",
"lnbits.*",
"lnurl.*",
"loguru.*",
Expand All @@ -32,7 +34,9 @@ module = [
"pyqrcode.*",
"shortuuid.*",
"httpx.*",
"secp256k1.*",
]
follow_imports = "skip"
ignore_missing_imports = "True"

[tool.pytest.ini_options]
Expand Down
11 changes: 7 additions & 4 deletions router.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@


class NostrRouter:
received_subscription_events: dict[str, List[EventMessage]] = {}
received_subscription_notices: list[NoticeMessage] = []
received_subscription_eosenotices: dict[str, EndOfStoredEventsMessage] = {}
received_subscription_events: dict[str, List[EventMessage]]
received_subscription_notices: list[NoticeMessage]
received_subscription_eosenotices: dict[str, EndOfStoredEventsMessage]

def __init__(self, websocket: WebSocket):
self.connected: bool = True
Expand Down Expand Up @@ -154,7 +154,10 @@ def _handle_client_close(self, subscription_id):
self.original_subscription_ids.pop(subscription_id_rewritten)
nostr_client.relay_manager.close_subscription(subscription_id_rewritten)
logger.info(
f"Unsubscribe from '{subscription_id_rewritten}'. Original id: '{subscription_id}.'"
f"""
Unsubscribe from '{subscription_id_rewritten}'.
Original id: '{subscription_id}.'
"""
)
else:
logger.info(f"Failed to unsubscribe from '{subscription_id}.'")
30 changes: 14 additions & 16 deletions tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ async def init_relays():
# get relays from db
relays = await get_relays()
# set relays and connect to them
valid_relays = list(set([r.url for r in relays if r.url]))
valid_relays = [r.url for r in relays if r.url]

nostr_client.reconnect(valid_relays)

Expand All @@ -29,34 +29,32 @@ async def check_relays():


async def subscribe_events():
while not any([r.connected for r in nostr_client.relay_manager.relays.values()]):
while not [r.connected for r in nostr_client.relay_manager.relays.values()]:
await asyncio.sleep(2)

def callback_events(eventMessage: EventMessage):
sub_id = eventMessage.subscription_id
def callback_events(event_message: EventMessage):
sub_id = event_message.subscription_id
if sub_id not in NostrRouter.received_subscription_events:
NostrRouter.received_subscription_events[sub_id] = [eventMessage]
NostrRouter.received_subscription_events[sub_id] = [event_message]
return

# do not add duplicate events (by event id)
ids = set(
[e.event_id for e in NostrRouter.received_subscription_events[sub_id]]
)
if eventMessage.event_id in ids:
ids = [e.event_id for e in NostrRouter.received_subscription_events[sub_id]]
if event_message.event_id in ids:
return

NostrRouter.received_subscription_events[sub_id].append(eventMessage)
NostrRouter.received_subscription_events[sub_id].append(event_message)

def callback_notices(noticeMessage: NoticeMessage):
if noticeMessage not in NostrRouter.received_subscription_notices:
NostrRouter.received_subscription_notices.append(noticeMessage)
def callback_notices(notice_message: NoticeMessage):
if notice_message not in NostrRouter.received_subscription_notices:
NostrRouter.received_subscription_notices.append(notice_message)

def callback_eose_notices(eventMessage: EndOfStoredEventsMessage):
sub_id = eventMessage.subscription_id
def callback_eose_notices(event_message: EndOfStoredEventsMessage):
sub_id = event_message.subscription_id
if sub_id in NostrRouter.received_subscription_eosenotices:
return

NostrRouter.received_subscription_eosenotices[sub_id] = eventMessage
NostrRouter.received_subscription_eosenotices[sub_id] = event_message

def wrap_async_subscribe():
asyncio.run(
Expand Down
15 changes: 10 additions & 5 deletions views.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
from fastapi import Depends, Request
from fastapi import APIRouter, Depends, Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from lnbits.core.models import User
from lnbits.decorators import check_admin
from starlette.responses import HTMLResponse

from . import nostr_renderer, nostrclient_ext
from lnbits.helpers import template_renderer

templates = Jinja2Templates(directory="templates")

nostrclient_generic_router = APIRouter()


def nostr_renderer():
return template_renderer(["nostrclient/templates"])


@nostrclient_ext.get("/", response_class=HTMLResponse)
@nostrclient_generic_router.get("/", response_class=HTMLResponse)
async def index(request: Request, user: User = Depends(check_admin)):
return nostr_renderer().TemplateResponse(
"nostrclient/index.html", {"request": request, "user": user.dict()}
Expand Down
55 changes: 28 additions & 27 deletions views_api.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import asyncio
from http import HTTPStatus
from typing import List

from fastapi import Depends, WebSocket
from fastapi import APIRouter, Depends, HTTPException, WebSocket
from lnbits.decorators import check_admin
from lnbits.helpers import decrypt_internal_message, urlsafe_short_hash
from loguru import logger
from starlette.exceptions import HTTPException

from . import all_routers, nostr_client, nostrclient_ext
from .crud import (
add_relay,
create_config,
Expand All @@ -18,13 +15,16 @@
update_config,
)
from .helpers import normalize_public_key
from .models import Config, Relay, TestMessage, TestMessageResponse
from .models import Config, Relay, RelayStatus, TestMessage, TestMessageResponse
from .nostr.key import EncryptedDirectMessage, PrivateKey
from .nostr_client import all_routers, nostr_client
from .router import NostrRouter

nostrclient_api_router = APIRouter()

@nostrclient_ext.get("/api/v1/relays", dependencies=[Depends(check_admin)])
async def api_get_relays() -> List[Relay]:

@nostrclient_api_router.get("/api/v1/relays", dependencies=[Depends(check_admin)])
async def api_get_relays() -> list[Relay]:
relays = []
for url, r in nostr_client.relay_manager.relays.items():
relay_id = urlsafe_short_hash()
Expand All @@ -33,24 +33,24 @@ async def api_get_relays() -> List[Relay]:
id=relay_id,
url=url,
connected=r.connected,
status={
"num_sent_events": r.num_sent_events,
"num_received_events": r.num_received_events,
"error_counter": r.error_counter,
"error_list": r.error_list,
"notice_list": r.notice_list,
},
status=RelayStatus(
num_sent_events=r.num_sent_events,
num_received_events=r.num_received_events,
error_counter=r.error_counter,
error_list=r.error_list,
notice_list=r.notice_list,
),
ping=r.ping,
active=True,
)
)
return relays


@nostrclient_ext.post(
@nostrclient_api_router.post(
"/api/v1/relay", status_code=HTTPStatus.OK, dependencies=[Depends(check_admin)]
)
async def api_add_relay(relay: Relay) -> List[Relay]:
async def api_add_relay(relay: Relay) -> list[Relay]:
if not relay.url:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, detail="Relay url not provided."
Expand All @@ -68,7 +68,7 @@ async def api_add_relay(relay: Relay) -> List[Relay]:
return await get_relays()


@nostrclient_ext.delete(
@nostrclient_api_router.delete(
"/api/v1/relay", status_code=HTTPStatus.OK, dependencies=[Depends(check_admin)]
)
async def api_delete_relay(relay: Relay) -> None:
Expand All @@ -81,7 +81,7 @@ async def api_delete_relay(relay: Relay) -> None:
await delete_relay(relay)


@nostrclient_ext.put(
@nostrclient_api_router.put(
"/api/v1/relay/test", status_code=HTTPStatus.OK, dependencies=[Depends(check_admin)]
)
async def api_test_endpoint(data: TestMessage) -> TestMessageResponse:
Expand All @@ -105,33 +105,34 @@ async def api_test_endpoint(data: TestMessage) -> TestMessageResponse:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=str(ex),
)
) from ex
except Exception as ex:
logger.warning(ex)
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail="Cannot generate test event",
)
) from ex


@nostrclient_ext.websocket("/api/v1/{id}")
async def ws_relay(id: str, websocket: WebSocket) -> None:
@nostrclient_api_router.websocket("/api/v1/{id}")
async def ws_relay(ws_id: str, websocket: WebSocket) -> None:
"""Relay multiplexer: one client (per endpoint) <-> multiple relays"""

logger.info("New websocket connection at: '/api/v1/relay'")
try:
config = await get_config()
assert config, "Failed to get config"

if not config.private_ws and not config.public_ws:
raise ValueError("Websocket connections not accepted.")

if id == "relay":
if ws_id == "relay":
if not config.public_ws:
raise ValueError("Public websocket connections not accepted.")
else:
if not config.private_ws:
raise ValueError("Private websocket connections not accepted.")
if decrypt_internal_message(id) != "relay":
if decrypt_internal_message(ws_id) != "relay":
raise ValueError("Invalid websocket endpoint.")

await websocket.accept()
Expand Down Expand Up @@ -160,10 +161,10 @@ async def ws_relay(id: str, websocket: WebSocket) -> None:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail="Cannot accept websocket connection",
)
) from ex


@nostrclient_ext.get("/api/v1/config", dependencies=[Depends(check_admin)])
@nostrclient_api_router.get("/api/v1/config", dependencies=[Depends(check_admin)])
async def api_get_config() -> Config:
config = await get_config()
if not config:
Expand All @@ -172,7 +173,7 @@ async def api_get_config() -> Config:
return config


@nostrclient_ext.put("/api/v1/config", dependencies=[Depends(check_admin)])
@nostrclient_api_router.put("/api/v1/config", dependencies=[Depends(check_admin)])
async def api_update_config(data: Config):
config = await update_config(data)
assert config
Expand Down

0 comments on commit 3405341

Please sign in to comment.