From ce8b95c2c7a700885ff533be126bd37876ccf8d7 Mon Sep 17 00:00:00 2001 From: Vlad Stan Date: Mon, 26 Jun 2023 10:21:06 +0300 Subject: [PATCH] feat: restart disconnected relays --- __init__.py | 4 +++- nostr/relay.py | 44 +++++++++++++++++++++--------------------- nostr/relay_manager.py | 37 ++++++++++++++++++++++++----------- tasks.py | 12 ++++++++++++ 4 files changed, 63 insertions(+), 34 deletions(-) diff --git a/__init__.py b/__init__.py index e98d8b8..ec6f4b4 100644 --- a/__init__.py +++ b/__init__.py @@ -36,7 +36,7 @@ def nostr_renderer(): return template_renderer(["lnbits/extensions/nostrclient/templates"]) -from .tasks import init_relays, subscribe_events +from .tasks import check_relays, init_relays, subscribe_events from .views import * # noqa from .views_api import * # noqa @@ -47,3 +47,5 @@ def nostrclient_start(): scheduled_tasks.append(task1) task2 = loop.create_task(catch_everything_and_restart(subscribe_events)) scheduled_tasks.append(task2) + task3 = loop.create_task(catch_everything_and_restart(check_relays)) + scheduled_tasks.append(task3) diff --git a/nostr/relay.py b/nostr/relay.py index cc992f0..6ff16f5 100644 --- a/nostr/relay.py +++ b/nostr/relay.py @@ -3,6 +3,7 @@ import time from queue import Queue from threading import Lock +from loguru import logger from websocket import WebSocketApp from .event import Event @@ -69,17 +70,12 @@ class Relay: def close(self): self.ws.close() + self.connected = False self.shutdown = True - def check_reconnect(self): - try: - self.close() - except: - pass - self.connected = False - if self.reconnect: - time.sleep(self.error_counter**2) - self.connect(self.ssl_options, self.proxy) + @property + def error_threshold_reached(self): + return self.error_threshold and self.error_counter > self.error_threshold @property def ping(self): @@ -104,6 +100,7 @@ class Relay: self.ws.send(message) except Exception as e: if self.shutdown: + logger.warning(f"Closing queue worker for {self.url}") break else: time.sleep(0.1) @@ -127,31 +124,34 @@ class Relay: ], } - def _on_open(self, class_obj): + def _on_open(self, _): + logger.info(f"Connected to relay: '{self.url}'.") self.connected = True - pass + - def _on_close(self, class_obj, status_code, message): - self.connected = False - if self.error_threshold and self.error_counter > self.error_threshold: - pass - else: - self.check_reconnect() - pass + def _on_close(self, _, status_code, message): + logger.warning(f"Connection to relay {self.url} closed. Status: '{status_code}'. Message: '{message}'.") + self.close() - def _on_message(self, class_obj, message: str): + + + + def _on_message(self, _, message: str): if self._is_valid_message(message): self.num_received_events += 1 self.message_pool.add_message(message, self.url) + else: + logger.debug(f"Invalid relay message: '{message}'.") - def _on_error(self, class_obj, error): + def _on_error(self, _, error): + logger.warning(f"Relay error: '{str(error)}'") self.connected = False self.error_counter += 1 - def _on_ping(self, class_obj, message): + def _on_ping(self, _*): return - def _on_pong(self, class_obj, message): + def _on_pong(self, _*): return def _is_valid_message(self, message: str) -> bool: diff --git a/nostr/relay_manager.py b/nostr/relay_manager.py index ddc833c..3ccc2fd 100644 --- a/nostr/relay_manager.py +++ b/nostr/relay_manager.py @@ -2,6 +2,8 @@ import ssl import threading +from loguru import logger + from .filter import Filters from .message_pool import MessagePool from .relay import Relay, RelayPolicy @@ -22,7 +24,7 @@ class RelayManager: self._subscriptions_lock = threading.Lock() def add_relay(self, url: str, read: bool = True, write: bool = True) -> Relay: - if url in self.relays: + if url in list(self.relays.keys()): return with self._subscriptions_lock: @@ -32,7 +34,7 @@ class RelayManager: relay = Relay(url, policy, self.message_pool, subscriptions) self.relays[url] = relay - self.open_connection( + self._open_connection( relay, {"cert_reqs": ssl.CERT_NONE} ) # NOTE: This disables ssl certificate verification @@ -62,8 +64,23 @@ class RelayManager: for relay in self.relays.values(): relay.close_subscription(id) + def check_and_restart_relays(self): + stopped_relays = [r for r in self.relays.values() if r.shutdown] + for relay in stopped_relays: + logger.info(f"Restarting connection to relay '{relay.url}'") + self._restart_relay(relay) - def open_connection(self, relay: Relay, ssl_options: dict = None, proxy: dict = None): + + def close_connections(self): + for relay in self.relays.values(): + relay.close() + + def publish_message(self, message: str): + for relay in self.relays.values(): + if relay.policy.should_write: + relay.publish(message) + + def _open_connection(self, relay: Relay, ssl_options: dict = None, proxy: dict = None): self.threads[relay.url] = threading.Thread( target=relay.connect, args=(ssl_options, proxy), @@ -79,11 +96,9 @@ class RelayManager: ) self.queue_threads[relay.url].start() - def close_connections(self): - for relay in self.relays.values(): - relay.close() - - def publish_message(self, message: str): - for relay in self.relays.values(): - if relay.policy.should_write: - relay.publish(message) + def _restart_relay(self, relay: Relay): + if relay.error_threshold_reached: + return + self.remove_relay(relay.url) + new_relay = self.add_relay(relay.url) + new_relay.error_counter = relay.error_counter \ No newline at end of file diff --git a/tasks.py b/tasks.py index 40ca9d9..05057e7 100644 --- a/tasks.py +++ b/tasks.py @@ -1,6 +1,8 @@ import asyncio import threading +from loguru import logger + from . import nostr from .crud import get_relays from .nostr.message_pool import EndOfStoredEventsMessage, EventMessage, NoticeMessage @@ -17,6 +19,16 @@ async def init_relays(): await nostr.client.connect() +async def check_relays(): + """ Check relays that have been disconnected """ + while True: + try: + await asyncio.sleep(20) + nostr.client.relay_manager.check_and_restart_relays() + except Exception as e: + logger.warning(f"Cannot restart relays: '{str(e)}'.") + + async def subscribe_events(): while not any([r.connected for r in nostr.client.relay_manager.relays.values()]): await asyncio.sleep(2)