feat: restart disconnected relays
This commit is contained in:
parent
c81a804dea
commit
ce8b95c2c7
4 changed files with 63 additions and 34 deletions
|
|
@ -36,7 +36,7 @@ def nostr_renderer():
|
|||
return template_renderer(["lnbits/extensions/nostrclient/templates"])
|
||||
|
||||
|
||||
from .tasks import init_relays, subscribe_events
|
||||
from .tasks import check_relays, init_relays, subscribe_events
|
||||
from .views import * # noqa
|
||||
from .views_api import * # noqa
|
||||
|
||||
|
|
@ -47,3 +47,5 @@ def nostrclient_start():
|
|||
scheduled_tasks.append(task1)
|
||||
task2 = loop.create_task(catch_everything_and_restart(subscribe_events))
|
||||
scheduled_tasks.append(task2)
|
||||
task3 = loop.create_task(catch_everything_and_restart(check_relays))
|
||||
scheduled_tasks.append(task3)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import time
|
|||
from queue import Queue
|
||||
from threading import Lock
|
||||
|
||||
from loguru import logger
|
||||
from websocket import WebSocketApp
|
||||
|
||||
from .event import Event
|
||||
|
|
@ -69,17 +70,12 @@ class Relay:
|
|||
|
||||
def close(self):
|
||||
self.ws.close()
|
||||
self.connected = False
|
||||
self.shutdown = True
|
||||
|
||||
def check_reconnect(self):
|
||||
try:
|
||||
self.close()
|
||||
except:
|
||||
pass
|
||||
self.connected = False
|
||||
if self.reconnect:
|
||||
time.sleep(self.error_counter**2)
|
||||
self.connect(self.ssl_options, self.proxy)
|
||||
@property
|
||||
def error_threshold_reached(self):
|
||||
return self.error_threshold and self.error_counter > self.error_threshold
|
||||
|
||||
@property
|
||||
def ping(self):
|
||||
|
|
@ -104,6 +100,7 @@ class Relay:
|
|||
self.ws.send(message)
|
||||
except Exception as e:
|
||||
if self.shutdown:
|
||||
logger.warning(f"Closing queue worker for {self.url}")
|
||||
break
|
||||
else:
|
||||
time.sleep(0.1)
|
||||
|
|
@ -127,31 +124,34 @@ class Relay:
|
|||
],
|
||||
}
|
||||
|
||||
def _on_open(self, class_obj):
|
||||
def _on_open(self, _):
|
||||
logger.info(f"Connected to relay: '{self.url}'.")
|
||||
self.connected = True
|
||||
pass
|
||||
|
||||
def _on_close(self, class_obj, status_code, message):
|
||||
self.connected = False
|
||||
if self.error_threshold and self.error_counter > self.error_threshold:
|
||||
pass
|
||||
else:
|
||||
self.check_reconnect()
|
||||
pass
|
||||
|
||||
def _on_message(self, class_obj, message: str):
|
||||
def _on_close(self, _, status_code, message):
|
||||
logger.warning(f"Connection to relay {self.url} closed. Status: '{status_code}'. Message: '{message}'.")
|
||||
self.close()
|
||||
|
||||
|
||||
|
||||
|
||||
def _on_message(self, _, message: str):
|
||||
if self._is_valid_message(message):
|
||||
self.num_received_events += 1
|
||||
self.message_pool.add_message(message, self.url)
|
||||
else:
|
||||
logger.debug(f"Invalid relay message: '{message}'.")
|
||||
|
||||
def _on_error(self, class_obj, error):
|
||||
def _on_error(self, _, error):
|
||||
logger.warning(f"Relay error: '{str(error)}'")
|
||||
self.connected = False
|
||||
self.error_counter += 1
|
||||
|
||||
def _on_ping(self, class_obj, message):
|
||||
def _on_ping(self, _*):
|
||||
return
|
||||
|
||||
def _on_pong(self, class_obj, message):
|
||||
def _on_pong(self, _*):
|
||||
return
|
||||
|
||||
def _is_valid_message(self, message: str) -> bool:
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@
|
|||
import ssl
|
||||
import threading
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from .filter import Filters
|
||||
from .message_pool import MessagePool
|
||||
from .relay import Relay, RelayPolicy
|
||||
|
|
@ -22,7 +24,7 @@ class RelayManager:
|
|||
self._subscriptions_lock = threading.Lock()
|
||||
|
||||
def add_relay(self, url: str, read: bool = True, write: bool = True) -> Relay:
|
||||
if url in self.relays:
|
||||
if url in list(self.relays.keys()):
|
||||
return
|
||||
|
||||
with self._subscriptions_lock:
|
||||
|
|
@ -32,7 +34,7 @@ class RelayManager:
|
|||
relay = Relay(url, policy, self.message_pool, subscriptions)
|
||||
self.relays[url] = relay
|
||||
|
||||
self.open_connection(
|
||||
self._open_connection(
|
||||
relay,
|
||||
{"cert_reqs": ssl.CERT_NONE}
|
||||
) # NOTE: This disables ssl certificate verification
|
||||
|
|
@ -62,8 +64,23 @@ class RelayManager:
|
|||
for relay in self.relays.values():
|
||||
relay.close_subscription(id)
|
||||
|
||||
def check_and_restart_relays(self):
|
||||
stopped_relays = [r for r in self.relays.values() if r.shutdown]
|
||||
for relay in stopped_relays:
|
||||
logger.info(f"Restarting connection to relay '{relay.url}'")
|
||||
self._restart_relay(relay)
|
||||
|
||||
def open_connection(self, relay: Relay, ssl_options: dict = None, proxy: dict = None):
|
||||
|
||||
def close_connections(self):
|
||||
for relay in self.relays.values():
|
||||
relay.close()
|
||||
|
||||
def publish_message(self, message: str):
|
||||
for relay in self.relays.values():
|
||||
if relay.policy.should_write:
|
||||
relay.publish(message)
|
||||
|
||||
def _open_connection(self, relay: Relay, ssl_options: dict = None, proxy: dict = None):
|
||||
self.threads[relay.url] = threading.Thread(
|
||||
target=relay.connect,
|
||||
args=(ssl_options, proxy),
|
||||
|
|
@ -79,11 +96,9 @@ class RelayManager:
|
|||
)
|
||||
self.queue_threads[relay.url].start()
|
||||
|
||||
def close_connections(self):
|
||||
for relay in self.relays.values():
|
||||
relay.close()
|
||||
|
||||
def publish_message(self, message: str):
|
||||
for relay in self.relays.values():
|
||||
if relay.policy.should_write:
|
||||
relay.publish(message)
|
||||
def _restart_relay(self, relay: Relay):
|
||||
if relay.error_threshold_reached:
|
||||
return
|
||||
self.remove_relay(relay.url)
|
||||
new_relay = self.add_relay(relay.url)
|
||||
new_relay.error_counter = relay.error_counter
|
||||
12
tasks.py
12
tasks.py
|
|
@ -1,6 +1,8 @@
|
|||
import asyncio
|
||||
import threading
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from . import nostr
|
||||
from .crud import get_relays
|
||||
from .nostr.message_pool import EndOfStoredEventsMessage, EventMessage, NoticeMessage
|
||||
|
|
@ -17,6 +19,16 @@ async def init_relays():
|
|||
await nostr.client.connect()
|
||||
|
||||
|
||||
async def check_relays():
|
||||
""" Check relays that have been disconnected """
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(20)
|
||||
nostr.client.relay_manager.check_and_restart_relays()
|
||||
except Exception as e:
|
||||
logger.warning(f"Cannot restart relays: '{str(e)}'.")
|
||||
|
||||
|
||||
async def subscribe_events():
|
||||
while not any([r.connected for r in nostr.client.relay_manager.relays.values()]):
|
||||
await asyncio.sleep(2)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue