diff --git a/nostr/client/client.py b/nostr/client/client.py index 6fb885f..6e70f71 100644 --- a/nostr/client/client.py +++ b/nostr/client/client.py @@ -141,7 +141,7 @@ class NostrClient: if callback_events_func: callback_events_func(event_msg) while self.relay_manager.message_pool.has_notices(): - event_msg = self.relay_manager.message_pool.has_notices() + event_msg = self.relay_manager.message_pool.get_notice() if callback_notices_func: callback_notices_func(event_msg) while self.relay_manager.message_pool.has_eose_notices(): diff --git a/nostr/relay.py b/nostr/relay.py index ee78baa..db9cacf 100644 --- a/nostr/relay.py +++ b/nostr/relay.py @@ -33,6 +33,7 @@ class Relay: self.subscriptions = subscriptions self.connected: bool = False self.reconnect: bool = True + self.shutdown: bool = False self.error_counter: int = 0 self.error_threshold: int = 0 self.num_received_events: int = 0 @@ -66,6 +67,7 @@ class Relay: def close(self): self.ws.close() + self.shutdown = True def check_reconnect(self): try: @@ -85,12 +87,16 @@ class Relay: def publish(self, message: str): self.queue.put(message) - def queue_worker(self): + def queue_worker(self, shutdown): while True: if self.connected: - message = self.queue.get() - self.num_sent_events += 1 - self.ws.send(message) + try: + message = self.queue.get(timeout=1) + self.num_sent_events += 1 + self.ws.send(message) + except: + if shutdown(): + break else: time.sleep(0.1) diff --git a/nostr/relay_manager.py b/nostr/relay_manager.py index 5b92d8d..a698a33 100644 --- a/nostr/relay_manager.py +++ b/nostr/relay_manager.py @@ -15,6 +15,8 @@ class RelayException(Exception): class RelayManager: def __init__(self) -> None: self.relays: dict[str, Relay] = {} + self.threads: dict[str, threading.Thread] = {} + self.queue_threads: dict[str, threading.Thread] = {} self.message_pool = MessagePool() def add_relay( @@ -25,7 +27,10 @@ class RelayManager: self.relays[url] = relay def remove_relay(self, url: str): + self.relays[url].close() self.relays.pop(url) + self.threads[url].join(timeout=1) + self.threads.pop(url) def add_subscription(self, id: str, filters: Filters): for relay in self.relays.values(): @@ -37,16 +42,21 @@ class RelayManager: def open_connections(self, ssl_options: dict = None, proxy: dict = None): for relay in self.relays.values(): - threading.Thread( + self.threads[relay.url] = threading.Thread( target=relay.connect, args=(ssl_options, proxy), name=f"{relay.url}-thread", daemon=True, - ).start() + ) + self.threads[relay.url].start() - threading.Thread( - target=relay.queue_worker, name=f"{relay.url}-queue", daemon=True - ).start() + self.queue_threads[relay.url] = threading.Thread( + target=relay.queue_worker, + args=(lambda: relay.shutdown,), + name=f"{relay.url}-queue", + daemon=True, + ) + self.queue_threads[relay.url].start() def close_connections(self): for relay in self.relays.values(): diff --git a/services.py b/services.py index e03ad1d..a801673 100644 --- a/services.py +++ b/services.py @@ -14,7 +14,7 @@ from .nostr.filter import Filters as NostrFilters from .nostr.message_pool import EndOfStoredEventsMessage, EventMessage, NoticeMessage received_subscription_events: dict[str, list[Event]] = {} -received_subscription_notices: dict[str, list[NoticeMessage]] = {} +received_subscription_notices: list[NoticeMessage] = [] received_subscription_eosenotices: dict[str, EndOfStoredEventsMessage] = {} @@ -62,7 +62,8 @@ class NostrRouter: stored in `my_subscriptions`. Then gets all responses for this subscription id from `received_subscription_events` which is filled in tasks.py. Takes one response after the other and relays it back to the client. Reconstructs the reponse manually because the nostr client lib we're using can't do it. Reconstructs the original subscription id - that we had previously rewritten in order to avoid collisions when multiple clients use the same id.""" + that we had previously rewritten in order to avoid collisions when multiple clients use the same id. + """ while True and self.connected: for s in self.subscriptions: if s in received_subscription_events: @@ -93,7 +94,17 @@ class NostrRouter: event_to_forward = ["EOSE", s_original] del received_subscription_eosenotices[s] # send data back to client + # print("Sending EOSE", event_to_forward) await self.websocket.send_text(json.dumps(event_to_forward)) + + # if s in received_subscription_notices: + while len(received_subscription_notices): + my_event = received_subscription_notices.pop(0) + event_to_forward = ["NOTICE", my_event.content] + # send data back to client + print("Received notice", event_to_forward) + # note: we don't send it to the user because we don't know who should receive it + # await self.websocket.send_text(json.dumps(event_to_forward)) await asyncio.sleep(0.1) async def start(self): @@ -128,7 +139,8 @@ class NostrRouter: """Parses a (string) request from a client. If it is a subscription (REQ), it will register the subscription in the nostr client library that we're using so we can receive the callbacks on it later. Will rewrite the subscription id since we expect - multiple clients to use the router and want to avoid subscription id collisions""" + multiple clients to use the router and want to avoid subscription id collisions + """ json_data = json.loads(json_str) assert len(json_data) if json_data[0] == "REQ": diff --git a/tasks.py b/tasks.py index 790337c..ab9a656 100644 --- a/tasks.py +++ b/tasks.py @@ -11,6 +11,7 @@ from .nostr.relay_manager import RelayManager from .services import ( nostr, received_subscription_eosenotices, + received_subscription_notices, received_subscription_events, ) @@ -68,7 +69,9 @@ async def subscribe_events(): ] return - def callback_notices(eventMessage: NoticeMessage): + def callback_notices(noticeMessage: NoticeMessage): + if noticeMessage not in received_subscription_notices: + received_subscription_notices.append(noticeMessage) return def callback_eose_notices(eventMessage: EndOfStoredEventsMessage):