feat: restart disconnected relays

This commit is contained in:
Vlad Stan 2023-06-26 10:21:06 +03:00
parent c81a804dea
commit ce8b95c2c7
4 changed files with 63 additions and 34 deletions

View file

@ -36,7 +36,7 @@ def nostr_renderer():
return template_renderer(["lnbits/extensions/nostrclient/templates"]) return template_renderer(["lnbits/extensions/nostrclient/templates"])
from .tasks import init_relays, subscribe_events from .tasks import check_relays, init_relays, subscribe_events
from .views import * # noqa from .views import * # noqa
from .views_api import * # noqa from .views_api import * # noqa
@ -47,3 +47,5 @@ def nostrclient_start():
scheduled_tasks.append(task1) scheduled_tasks.append(task1)
task2 = loop.create_task(catch_everything_and_restart(subscribe_events)) task2 = loop.create_task(catch_everything_and_restart(subscribe_events))
scheduled_tasks.append(task2) scheduled_tasks.append(task2)
task3 = loop.create_task(catch_everything_and_restart(check_relays))
scheduled_tasks.append(task3)

View file

@ -3,6 +3,7 @@ import time
from queue import Queue from queue import Queue
from threading import Lock from threading import Lock
from loguru import logger
from websocket import WebSocketApp from websocket import WebSocketApp
from .event import Event from .event import Event
@ -69,17 +70,12 @@ class Relay:
def close(self): def close(self):
self.ws.close() self.ws.close()
self.connected = False
self.shutdown = True self.shutdown = True
def check_reconnect(self): @property
try: def error_threshold_reached(self):
self.close() return self.error_threshold and self.error_counter > self.error_threshold
except:
pass
self.connected = False
if self.reconnect:
time.sleep(self.error_counter**2)
self.connect(self.ssl_options, self.proxy)
@property @property
def ping(self): def ping(self):
@ -104,6 +100,7 @@ class Relay:
self.ws.send(message) self.ws.send(message)
except Exception as e: except Exception as e:
if self.shutdown: if self.shutdown:
logger.warning(f"Closing queue worker for {self.url}")
break break
else: else:
time.sleep(0.1) time.sleep(0.1)
@ -127,31 +124,34 @@ class Relay:
], ],
} }
def _on_open(self, class_obj): def _on_open(self, _):
logger.info(f"Connected to relay: '{self.url}'.")
self.connected = True self.connected = True
pass
def _on_close(self, class_obj, status_code, message): def _on_close(self, _, status_code, message):
self.connected = False logger.warning(f"Connection to relay {self.url} closed. Status: '{status_code}'. Message: '{message}'.")
if self.error_threshold and self.error_counter > self.error_threshold: self.close()
pass
else:
self.check_reconnect()
pass
def _on_message(self, class_obj, message: str):
def _on_message(self, _, message: str):
if self._is_valid_message(message): if self._is_valid_message(message):
self.num_received_events += 1 self.num_received_events += 1
self.message_pool.add_message(message, self.url) self.message_pool.add_message(message, self.url)
else:
logger.debug(f"Invalid relay message: '{message}'.")
def _on_error(self, class_obj, error): def _on_error(self, _, error):
logger.warning(f"Relay error: '{str(error)}'")
self.connected = False self.connected = False
self.error_counter += 1 self.error_counter += 1
def _on_ping(self, class_obj, message): def _on_ping(self, _*):
return return
def _on_pong(self, class_obj, message): def _on_pong(self, _*):
return return
def _is_valid_message(self, message: str) -> bool: def _is_valid_message(self, message: str) -> bool:

View file

@ -2,6 +2,8 @@
import ssl import ssl
import threading import threading
from loguru import logger
from .filter import Filters from .filter import Filters
from .message_pool import MessagePool from .message_pool import MessagePool
from .relay import Relay, RelayPolicy from .relay import Relay, RelayPolicy
@ -22,7 +24,7 @@ class RelayManager:
self._subscriptions_lock = threading.Lock() self._subscriptions_lock = threading.Lock()
def add_relay(self, url: str, read: bool = True, write: bool = True) -> Relay: def add_relay(self, url: str, read: bool = True, write: bool = True) -> Relay:
if url in self.relays: if url in list(self.relays.keys()):
return return
with self._subscriptions_lock: with self._subscriptions_lock:
@ -32,7 +34,7 @@ class RelayManager:
relay = Relay(url, policy, self.message_pool, subscriptions) relay = Relay(url, policy, self.message_pool, subscriptions)
self.relays[url] = relay self.relays[url] = relay
self.open_connection( self._open_connection(
relay, relay,
{"cert_reqs": ssl.CERT_NONE} {"cert_reqs": ssl.CERT_NONE}
) # NOTE: This disables ssl certificate verification ) # NOTE: This disables ssl certificate verification
@ -62,8 +64,23 @@ class RelayManager:
for relay in self.relays.values(): for relay in self.relays.values():
relay.close_subscription(id) relay.close_subscription(id)
def check_and_restart_relays(self):
stopped_relays = [r for r in self.relays.values() if r.shutdown]
for relay in stopped_relays:
logger.info(f"Restarting connection to relay '{relay.url}'")
self._restart_relay(relay)
def open_connection(self, relay: Relay, ssl_options: dict = None, proxy: dict = None):
def close_connections(self):
for relay in self.relays.values():
relay.close()
def publish_message(self, message: str):
for relay in self.relays.values():
if relay.policy.should_write:
relay.publish(message)
def _open_connection(self, relay: Relay, ssl_options: dict = None, proxy: dict = None):
self.threads[relay.url] = threading.Thread( self.threads[relay.url] = threading.Thread(
target=relay.connect, target=relay.connect,
args=(ssl_options, proxy), args=(ssl_options, proxy),
@ -79,11 +96,9 @@ class RelayManager:
) )
self.queue_threads[relay.url].start() self.queue_threads[relay.url].start()
def close_connections(self): def _restart_relay(self, relay: Relay):
for relay in self.relays.values(): if relay.error_threshold_reached:
relay.close() return
self.remove_relay(relay.url)
def publish_message(self, message: str): new_relay = self.add_relay(relay.url)
for relay in self.relays.values(): new_relay.error_counter = relay.error_counter
if relay.policy.should_write:
relay.publish(message)

View file

@ -1,6 +1,8 @@
import asyncio import asyncio
import threading import threading
from loguru import logger
from . import nostr from . import nostr
from .crud import get_relays from .crud import get_relays
from .nostr.message_pool import EndOfStoredEventsMessage, EventMessage, NoticeMessage from .nostr.message_pool import EndOfStoredEventsMessage, EventMessage, NoticeMessage
@ -17,6 +19,16 @@ async def init_relays():
await nostr.client.connect() await nostr.client.connect()
async def check_relays():
""" Check relays that have been disconnected """
while True:
try:
await asyncio.sleep(20)
nostr.client.relay_manager.check_and_restart_relays()
except Exception as e:
logger.warning(f"Cannot restart relays: '{str(e)}'.")
async def subscribe_events(): async def subscribe_events():
while not any([r.connected for r in nostr.client.relay_manager.relays.values()]): while not any([r.connected for r in nostr.client.relay_manager.relays.values()]):
await asyncio.sleep(2) await asyncio.sleep(2)