diff --git a/nostr/relay.py b/nostr/relay.py index 94e532c..8d3545b 100644 --- a/nostr/relay.py +++ b/nostr/relay.py @@ -121,6 +121,7 @@ class Relay: def close_subscription(self, id: str) -> None: with self.lock: self.subscriptions.pop(id) + self.publish(json.dumps(["CLOSE", id])) def to_json_object(self) -> dict: return { diff --git a/services.py b/services.py index b270539..42e8c2c 100644 --- a/services.py +++ b/services.py @@ -2,17 +2,16 @@ import asyncio import json from typing import List, Union -from fastapi import WebSocket, WebSocketDisconnect +from fastapi import WebSocketDisconnect from loguru import logger from lnbits.helpers import urlsafe_short_hash -from .models import Event, Filter, Filters, Relay, RelayList +from .models import Event, Filter from .nostr.client.client import NostrClient as NostrClientLib -from .nostr.event import Event as NostrEvent from .nostr.filter import Filter as NostrFilter from .nostr.filter import Filters as NostrFilters -from .nostr.message_pool import EndOfStoredEventsMessage, EventMessage, NoticeMessage +from .nostr.message_pool import EndOfStoredEventsMessage, NoticeMessage received_subscription_events: dict[str, list[Event]] = {} received_subscription_notices: list[NoticeMessage] = [] @@ -33,7 +32,7 @@ class NostrRouter: self.connected: bool = True self.websocket = websocket self.tasks: List[asyncio.Task] = [] - self.oridinal_subscription_ids = {} + self.original_subscription_ids = {} async def client_to_nostr(self): """Receives requests / data from the client and forwards it to relays. If the @@ -47,17 +46,7 @@ class NostrRouter: self.connected = False break - # registers a subscription if the input was a REQ request - subscription_id, json_str_rewritten = await self._handle_nostr_subscription( - json_str - ) - - if subscription_id and json_str_rewritten: - self.subscriptions.append(subscription_id) - - # publish data - publish_data = json_str_rewritten or json_str - nostr.client.relay_manager.publish_message(publish_data) + await self._handle_client_to_nostr(json_str) async def nostr_to_client(self): """Sends responses from relays back to the client. Polls the subscriptions of this client @@ -67,50 +56,12 @@ class NostrRouter: that we had previously rewritten in order to avoid collisions when multiple clients use the same id. """ while True and self.connected: - for s in self.subscriptions: - if s in received_subscription_events: - while len(received_subscription_events[s]): - my_event = received_subscription_events[s].pop(0) - # event.to_message() does not include the subscription ID, we have to add it manually - event_json = { - "id": my_event.id, - "pubkey": my_event.public_key, - "created_at": my_event.created_at, - "kind": my_event.kind, - "tags": my_event.tags, - "content": my_event.content, - "sig": my_event.signature, - } - - # this reconstructs the original response from the relay - # reconstruct original subscription id - s_original = self.oridinal_subscription_ids[s] - event_to_forward = ["EVENT", s_original, event_json] - - # print("Event to forward") - # print(json.dumps(event_to_forward)) - - # send data back to client - await self.websocket.send_text(json.dumps(event_to_forward)) - if s in received_subscription_eosenotices: - my_event = received_subscription_eosenotices[s] - s_original = self.oridinal_subscription_ids[s] - event_to_forward = ["EOSE", s_original] - del received_subscription_eosenotices[s] - # send data back to client - # print("Sending EOSE", event_to_forward) - await self.websocket.send_text(json.dumps(event_to_forward)) - - # if s in received_subscription_notices: - while len(received_subscription_notices): - my_event = received_subscription_notices.pop(0) - event_to_forward = ["NOTICE", my_event.content] - # send data back to client - logger.debug("Nostrclient: Received notice", event_to_forward[1]) - # note: we don't send it to the user because we don't know who should receive it - # await self.websocket.send_text(json.dumps(event_to_forward)) + await self._handle_subscriptions() + self._handle_notices() + await asyncio.sleep(0.1) + async def start(self): self.tasks.append(asyncio.create_task(self.client_to_nostr())) self.tasks.append(asyncio.create_task(self.nostr_to_client())) @@ -120,6 +71,53 @@ class NostrRouter: t.cancel() self.connected = False + async def _handle_subscriptions(self): + for s in self.subscriptions: + if s in received_subscription_events: + await self._handle_received_subscription_events(s) + if s in received_subscription_eosenotices: + await self._handle_received_subscription_eosenotices(s) + + + + async def _handle_received_subscription_eosenotices(self, s): + my_event = received_subscription_eosenotices[s] + s_original = self.original_subscription_ids[s] + event_to_forward = ["EOSE", s_original] + del received_subscription_eosenotices[s] + # send data back to client + # print("Sending EOSE", event_to_forward) + await self.websocket.send_text(json.dumps(event_to_forward)) + + async def _handle_received_subscription_events(self, s): + while len(received_subscription_events[s]): + my_event = received_subscription_events[s].pop(0) + # event.to_message() does not include the subscription ID, we have to add it manually + event_json = { + "id": my_event.id, + "pubkey": my_event.public_key, + "created_at": my_event.created_at, + "kind": my_event.kind, + "tags": my_event.tags, + "content": my_event.content, + "sig": my_event.signature, + } + + # this reconstructs the original response from the relay + # reconstruct original subscription id + s_original = self.original_subscription_ids[s] + event_to_forward = ["EVENT", s_original, event_json] + await self.websocket.send_text(json.dumps(event_to_forward)) + + def _handle_notices(self): + while len(received_subscription_notices): + my_event = received_subscription_notices.pop(0) + event_to_forward = ["NOTICE", my_event.content] + # note: we don't send it to the user because we don't know who should receive it + logger.debug("Nostrclient: Received notice", event_to_forward[1]) + + + def _marshall_nostr_filters(self, data: Union[dict, list]): filters = data if isinstance(data, list) else [data] filters = [Filter.parse_obj(f) for f in filters] @@ -139,7 +137,7 @@ class NostrRouter: ) return NostrFilters(filter_list) - async def _handle_nostr_subscription(self, json_str): + async def _handle_client_to_nostr(self, json_str): """Parses a (string) request from a client. If it is a subscription (REQ) or a CLOSE, it will register the subscription in the nostr client library that we're using so we can receive the callbacks on it later. Will rewrite the subscription id since we expect @@ -147,25 +145,35 @@ class NostrRouter: """ json_data = json.loads(json_str) assert len(json_data) + if json_data[0] == "REQ": - subscription_id = json_data[1] - subscription_id_rewritten = urlsafe_short_hash() - self.oridinal_subscription_ids[subscription_id_rewritten] = subscription_id - fltr = json_data[2] - filters = self._marshall_nostr_filters(fltr) - nostr.client.relay_manager.add_subscription( + self._handle_client_req(json_data) + return + + if json_data[0] == "CLOSE": + self.handle_client_close(json_data[1]) + return + + if json_data[0] == "EVENT": + nostr.client.relay_manager.publish_message(json_str) + return + + def _handle_client_req(self, json_data): + subscription_id = json_data[1] + subscription_id_rewritten = urlsafe_short_hash() + self.original_subscription_ids[subscription_id_rewritten] = subscription_id + fltr = json_data[2] + filters = self._marshall_nostr_filters(fltr) + + nostr.client.relay_manager.add_subscription( subscription_id_rewritten, filters ) - request_rewritten = json.dumps( - [json_data[0], subscription_id_rewritten, fltr] - ) - return subscription_id_rewritten, request_rewritten - elif json_data[0] == "CLOSE": - subscription_id = json_data[1] - subscription_id_rewritten = next((k for k, v in self.oridinal_subscription_ids.items() if v == subscription_id), None) - if subscription_id_rewritten: - nostr.client.relay_manager.close_subscription(subscription_id_rewritten) - request_rewritten = json.dumps([json_data[0], subscription_id_rewritten]) - return None, request_rewritten + request_rewritten = json.dumps([json_data[0], subscription_id_rewritten, fltr]) + + self.subscriptions.append(subscription_id_rewritten) + nostr.client.relay_manager.publish_message(request_rewritten) - return None, None + def handle_client_close(self, subscription_id): + subscription_id_rewritten = next((k for k, v in self.original_subscription_ids.items() if v == subscription_id), None) + if subscription_id_rewritten: + nostr.client.relay_manager.close_subscription(subscription_id_rewritten)