Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to use the NWC versioning system. #203

Merged
merged 1 commit into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 96 additions & 5 deletions nwc_backend/event_handlers/__tests__/nip47_event_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,12 @@ def create_request_event(
self,
method: Nip47RequestMethod = Nip47RequestMethod.PAY_INVOICE,
params: Optional[dict[str, Any]] = None,
version: Optional[str] = "1.0",
use_nip44: bool = True,
) -> Event:
if params is None:
params = self.get_default_request_params()
return (
builder = (
EventBuilder(
kind=KindEnum.WALLET_CONNECT_REQUEST(), # pyre-ignore[6]
content=json.dumps(
Expand All @@ -70,8 +71,10 @@ def create_request_event(
)
.encrypt_content(self.nwc_keys.public_key(), use_nip44=use_nip44)
.add_tag(["p", self.nwc_keys.public_key().to_hex()])
.build()
)
if version:
builder.add_tag(["v", version])
return builder.build()

def get_default_request_params(self) -> dict[str, Any]:
return {
Expand Down Expand Up @@ -191,7 +194,10 @@ async def test_failed__invalid_input_params(
granted_permissions_groups=[PermissionsGroup.SEND_PAYMENTS],
keys=harness.client_app_keys,
)
request_event = harness.create_request_event(params={}, use_nip44=use_nip44)
version = "1.0" if use_nip44 else None
request_event = harness.create_request_event(
params={}, use_nip44=use_nip44, version=version
)
await handle_nip47_event(request_event)

mock_nostr_send.assert_called_once()
Expand Down Expand Up @@ -232,7 +238,10 @@ async def test_succeeded(
granted_permissions_groups=[PermissionsGroup.SEND_PAYMENTS],
keys=harness.client_app_keys,
)
request_event = harness.create_request_event(use_nip44=use_nip44)
version = "1.0" if use_nip44 else None
request_event = harness.create_request_event(
use_nip44=use_nip44, version=version
)
await handle_nip47_event(request_event)

mock_nostr_send.assert_called_once()
Expand Down Expand Up @@ -289,7 +298,10 @@ async def test_failed__vasp_error_response(
granted_permissions_groups=[PermissionsGroup.SEND_PAYMENTS],
keys=harness.client_app_keys,
)
request_event = harness.create_request_event(use_nip44=use_nip44)
version = "1.0" if use_nip44 else None
request_event = harness.create_request_event(
use_nip44=use_nip44, version=version
)
await handle_nip47_event(request_event)

mock_nostr_send.assert_called_once()
Expand Down Expand Up @@ -327,3 +339,82 @@ async def test_duplicate_event(
result = await db.session.execute(select(Nip47Request))
request = result.scalars().one()
assert request.id == nip47_event.id


@patch("nwc_backend.nostr.nostr_client.nostr_client.send_event", new_callable=AsyncMock)
async def test_failed__invalid_version(
mock_nostr_send: AsyncMock,
test_client: QuartClient,
) -> None:
mock_nostr_send.return_value = SendEventOutput(
id=EventId.from_hex(token_hex()),
output=Output(success=["wss://relay.getalby.com/v1"], failed={}),
)
async with test_client.app.app_context():
harness = Harness.prepare()
await create_nwc_connection(
granted_permissions_groups=[PermissionsGroup.SEND_PAYMENTS],
keys=harness.client_app_keys,
)
request_event = harness.create_request_event(params={}, version="abc")
await handle_nip47_event(request_event)

mock_nostr_send.assert_called_once()
response_event = mock_nostr_send.call_args[0][0]
content = harness.validate_response_event(response_event, request_event.id())
assert content["result_type"] == Nip47RequestMethod.PAY_INVOICE.value
assert content["error"]["code"] == ErrorCode.OTHER.name


@patch("nwc_backend.nostr.nostr_client.nostr_client.send_event", new_callable=AsyncMock)
async def test_failed__unsupported(
mock_nostr_send: AsyncMock,
test_client: QuartClient,
) -> None:
mock_nostr_send.return_value = SendEventOutput(
id=EventId.from_hex(token_hex()),
output=Output(success=["wss://relay.getalby.com/v1"], failed={}),
)
async with test_client.app.app_context():
harness = Harness.prepare()
await create_nwc_connection(
granted_permissions_groups=[PermissionsGroup.SEND_PAYMENTS],
keys=harness.client_app_keys,
)
request_event = harness.create_request_event(params={}, version="10.0")
await handle_nip47_event(request_event)

mock_nostr_send.assert_called_once()
response_event = mock_nostr_send.call_args[0][0]
content = harness.validate_response_event(response_event, request_event.id())
assert content["result_type"] == Nip47RequestMethod.PAY_INVOICE.value
assert content["error"]["code"] == ErrorCode.NOT_IMPLEMENTED.name


@patch("nwc_backend.nostr.nostr_client.nostr_client.send_event", new_callable=AsyncMock)
async def test_failed__wrong_encryption_for_version(
mock_nostr_send: AsyncMock,
test_client: QuartClient,
) -> None:
mock_nostr_send.return_value = SendEventOutput(
id=EventId.from_hex(token_hex()),
output=Output(success=["wss://relay.getalby.com/v1"], failed={}),
)
async with test_client.app.app_context():
harness = Harness.prepare()
await create_nwc_connection(
granted_permissions_groups=[PermissionsGroup.SEND_PAYMENTS],
keys=harness.client_app_keys,
)
request_event = harness.create_request_event(
params={}, version="1.0", use_nip44=False
)
await handle_nip47_event(request_event)

mock_nostr_send.assert_called_once()
response_event = mock_nostr_send.call_args[0][0]
content = harness.validate_response_event(
response_event, request_event.id(), expect_nip44=False
)
assert content["result_type"] == Nip47RequestMethod.PAY_INVOICE.value
assert content["error"]["code"] == ErrorCode.OTHER.name
47 changes: 47 additions & 0 deletions nwc_backend/event_handlers/nip47_event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from nwc_backend.models.nwc_connection import NWCConnection
from nwc_backend.nostr.nostr_client import nostr_client
from nwc_backend.nostr.nostr_config import NostrConfig
from nwc_backend.nostr.versions import ParsedVersion, is_version_supported


async def handle_nip47_event(event: Event) -> None:
Expand Down Expand Up @@ -77,6 +78,22 @@ async def handle_nip47_event(event: Event) -> None:
return

method = Nip47RequestMethod(content["method"])

try:
_check_version(event)
except Nip47RequestException as ex:
error_response = create_nip47_error_response(
event=event,
method=method,
error=Nip47Error(
code=ex.error_code,
message=ex.error_message,
),
use_nip44=not is_nip04_encrypted,
)
await nostr_client.send_event(error_response)
return

if not nwc_connection.has_command_permission(method):
error_response = create_nip47_error_response(
event=event,
Expand Down Expand Up @@ -187,3 +204,33 @@ async def handle_nip47_event(event: Event) -> None:
await nip47_request.update_response_and_save(
response_event_id=output.id.to_hex(), response=response
)


def _check_version(event: Event) -> ParsedVersion:
is_nip04_encrypted = "?iv=" in event.content()
selected_version = ParsedVersion(0, 0)
version_tag = next((tag for tag in event.tags() if tag.as_vec()[0] == "v"), None)
if version_tag:
selected_version_str = version_tag.content() or "0.0"
try:
selected_version = ParsedVersion.load(selected_version_str)
except ValueError:
raise Nip47RequestException(
error_code=ErrorCode.OTHER,
error_message=f"Invalid version {selected_version_str}.",
)

if not is_version_supported(selected_version):
raise Nip47RequestException(
# TODO: Use ErrorCode.VERSION_NOT_SUPPORTED when added.
error_code=ErrorCode.NOT_IMPLEMENTED,
error_message=f"Unsupported version {selected_version}.",
)

if selected_version.major > 0 and is_nip04_encrypted:
raise Nip47RequestException(
error_code=ErrorCode.OTHER,
error_message="NIP04 encryption is not supported for version > 0. Please use NIP44.",
)

return selected_version
13 changes: 9 additions & 4 deletions nwc_backend/nostr/nostr_client_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from nwc_backend.models.nip47_request_method import Nip47RequestMethod
from nwc_backend.nostr.nostr_client import nostr_client
from nwc_backend.nostr.nostr_config import NostrConfig
from nwc_backend.nostr.versions import NWC_VERSIONS_SUPPORTED


class NotificationHandler(HandleNotification):
Expand Down Expand Up @@ -49,10 +50,14 @@ async def init_nostr_client() -> None:


async def _publish_nip47_info() -> None:
nip47_info_event = EventBuilder(
kind=KindEnum.WALLET_CONNECT_INFO(), # pyre-ignore[6]
content=" ".join([method.value for method in list(Nip47RequestMethod)]),
).build()
nip47_info_event = (
EventBuilder(
kind=KindEnum.WALLET_CONNECT_INFO(), # pyre-ignore[6]
content=" ".join([method.value for method in list(Nip47RequestMethod)]),
)
.add_tag(["v", " ".join(NWC_VERSIONS_SUPPORTED)])
.build()
)
response = await nostr_client.send_event(nip47_info_event)

logging.debug(
Expand Down
34 changes: 34 additions & 0 deletions nwc_backend/nostr/versions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from dataclasses import dataclass
from functools import total_ordering
from typing import List


NWC_VERSIONS_SUPPORTED: List[str] = ["0.0", "1.0"]


@total_ordering
@dataclass
class ParsedVersion:
major: int
minor: int

@classmethod
def load(cls, version: str) -> "ParsedVersion":
[major, minor] = version.split(".")
return ParsedVersion(major=int(major), minor=int(minor))

def __str__(self) -> str:
return f"{self.major}.{self.minor}"

def __lt__(self, other: "ParsedVersion") -> bool:
return self.major < other.major or (
self.major == other.major and self.minor < other.minor
)


def is_version_supported(version: ParsedVersion) -> bool:
for version_str in NWC_VERSIONS_SUPPORTED:
supported_version = ParsedVersion.load(version_str)
if version.major == supported_version.major:
return version.minor <= supported_version.minor
return False
Loading