fix endless loop

This commit is contained in:
callebtc 2023-04-17 15:47:03 +02:00
parent 37b48b7c0c
commit 33df69c73a
5 changed files with 45 additions and 14 deletions

View file

@ -141,7 +141,7 @@ class NostrClient:
if callback_events_func: if callback_events_func:
callback_events_func(event_msg) callback_events_func(event_msg)
while self.relay_manager.message_pool.has_notices(): while self.relay_manager.message_pool.has_notices():
event_msg = self.relay_manager.message_pool.has_notices() event_msg = self.relay_manager.message_pool.get_notice()
if callback_notices_func: if callback_notices_func:
callback_notices_func(event_msg) callback_notices_func(event_msg)
while self.relay_manager.message_pool.has_eose_notices(): while self.relay_manager.message_pool.has_eose_notices():

View file

@ -33,6 +33,7 @@ class Relay:
self.subscriptions = subscriptions self.subscriptions = subscriptions
self.connected: bool = False self.connected: bool = False
self.reconnect: bool = True self.reconnect: bool = True
self.shutdown: bool = False
self.error_counter: int = 0 self.error_counter: int = 0
self.error_threshold: int = 0 self.error_threshold: int = 0
self.num_received_events: int = 0 self.num_received_events: int = 0
@ -66,6 +67,7 @@ class Relay:
def close(self): def close(self):
self.ws.close() self.ws.close()
self.shutdown = True
def check_reconnect(self): def check_reconnect(self):
try: try:
@ -85,12 +87,16 @@ class Relay:
def publish(self, message: str): def publish(self, message: str):
self.queue.put(message) self.queue.put(message)
def queue_worker(self): def queue_worker(self, shutdown):
while True: while True:
if self.connected: if self.connected:
message = self.queue.get() try:
self.num_sent_events += 1 message = self.queue.get(timeout=1)
self.ws.send(message) self.num_sent_events += 1
self.ws.send(message)
except:
if shutdown():
break
else: else:
time.sleep(0.1) time.sleep(0.1)

View file

@ -15,6 +15,8 @@ class RelayException(Exception):
class RelayManager: class RelayManager:
def __init__(self) -> None: def __init__(self) -> None:
self.relays: dict[str, Relay] = {} self.relays: dict[str, Relay] = {}
self.threads: dict[str, threading.Thread] = {}
self.queue_threads: dict[str, threading.Thread] = {}
self.message_pool = MessagePool() self.message_pool = MessagePool()
def add_relay( def add_relay(
@ -25,7 +27,10 @@ class RelayManager:
self.relays[url] = relay self.relays[url] = relay
def remove_relay(self, url: str): def remove_relay(self, url: str):
self.relays[url].close()
self.relays.pop(url) self.relays.pop(url)
self.threads[url].join(timeout=1)
self.threads.pop(url)
def add_subscription(self, id: str, filters: Filters): def add_subscription(self, id: str, filters: Filters):
for relay in self.relays.values(): for relay in self.relays.values():
@ -37,16 +42,21 @@ class RelayManager:
def open_connections(self, ssl_options: dict = None, proxy: dict = None): def open_connections(self, ssl_options: dict = None, proxy: dict = None):
for relay in self.relays.values(): for relay in self.relays.values():
threading.Thread( self.threads[relay.url] = threading.Thread(
target=relay.connect, target=relay.connect,
args=(ssl_options, proxy), args=(ssl_options, proxy),
name=f"{relay.url}-thread", name=f"{relay.url}-thread",
daemon=True, daemon=True,
).start() )
self.threads[relay.url].start()
threading.Thread( self.queue_threads[relay.url] = threading.Thread(
target=relay.queue_worker, name=f"{relay.url}-queue", daemon=True target=relay.queue_worker,
).start() args=(lambda: relay.shutdown,),
name=f"{relay.url}-queue",
daemon=True,
)
self.queue_threads[relay.url].start()
def close_connections(self): def close_connections(self):
for relay in self.relays.values(): for relay in self.relays.values():

View file

@ -14,7 +14,7 @@ from .nostr.filter import Filters as NostrFilters
from .nostr.message_pool import EndOfStoredEventsMessage, EventMessage, NoticeMessage from .nostr.message_pool import EndOfStoredEventsMessage, EventMessage, NoticeMessage
received_subscription_events: dict[str, list[Event]] = {} received_subscription_events: dict[str, list[Event]] = {}
received_subscription_notices: dict[str, list[NoticeMessage]] = {} received_subscription_notices: list[NoticeMessage] = []
received_subscription_eosenotices: dict[str, EndOfStoredEventsMessage] = {} received_subscription_eosenotices: dict[str, EndOfStoredEventsMessage] = {}
@ -62,7 +62,8 @@ class NostrRouter:
stored in `my_subscriptions`. Then gets all responses for this subscription id from `received_subscription_events` which stored in `my_subscriptions`. Then gets all responses for this subscription id from `received_subscription_events` which
is filled in tasks.py. Takes one response after the other and relays it back to the client. Reconstructs is filled in tasks.py. Takes one response after the other and relays it back to the client. Reconstructs
the reponse manually because the nostr client lib we're using can't do it. Reconstructs the original subscription id the reponse manually because the nostr client lib we're using can't do it. Reconstructs the original subscription id
that we had previously rewritten in order to avoid collisions when multiple clients use the same id.""" that we had previously rewritten in order to avoid collisions when multiple clients use the same id.
"""
while True and self.connected: while True and self.connected:
for s in self.subscriptions: for s in self.subscriptions:
if s in received_subscription_events: if s in received_subscription_events:
@ -93,7 +94,17 @@ class NostrRouter:
event_to_forward = ["EOSE", s_original] event_to_forward = ["EOSE", s_original]
del received_subscription_eosenotices[s] del received_subscription_eosenotices[s]
# send data back to client # send data back to client
# print("Sending EOSE", event_to_forward)
await self.websocket.send_text(json.dumps(event_to_forward)) await self.websocket.send_text(json.dumps(event_to_forward))
# if s in received_subscription_notices:
while len(received_subscription_notices):
my_event = received_subscription_notices.pop(0)
event_to_forward = ["NOTICE", my_event.content]
# send data back to client
print("Received notice", event_to_forward)
# note: we don't send it to the user because we don't know who should receive it
# await self.websocket.send_text(json.dumps(event_to_forward))
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
async def start(self): async def start(self):
@ -128,7 +139,8 @@ class NostrRouter:
"""Parses a (string) request from a client. If it is a subscription (REQ), it will """Parses a (string) request from a client. If it is a subscription (REQ), it will
register the subscription in the nostr client library that we're using so we can register the subscription in the nostr client library that we're using so we can
receive the callbacks on it later. Will rewrite the subscription id since we expect receive the callbacks on it later. Will rewrite the subscription id since we expect
multiple clients to use the router and want to avoid subscription id collisions""" multiple clients to use the router and want to avoid subscription id collisions
"""
json_data = json.loads(json_str) json_data = json.loads(json_str)
assert len(json_data) assert len(json_data)
if json_data[0] == "REQ": if json_data[0] == "REQ":

View file

@ -11,6 +11,7 @@ from .nostr.relay_manager import RelayManager
from .services import ( from .services import (
nostr, nostr,
received_subscription_eosenotices, received_subscription_eosenotices,
received_subscription_notices,
received_subscription_events, received_subscription_events,
) )
@ -68,7 +69,9 @@ async def subscribe_events():
] ]
return return
def callback_notices(eventMessage: NoticeMessage): def callback_notices(noticeMessage: NoticeMessage):
if noticeMessage not in received_subscription_notices:
received_subscription_notices.append(noticeMessage)
return return
def callback_eose_notices(eventMessage: EndOfStoredEventsMessage): def callback_eose_notices(eventMessage: EndOfStoredEventsMessage):