* refactor: clean-up * refactor: extra logs plus try-catch * refactor: do not use bare `except` * refactor: clean-up redundant fields * chore: pass code checks * chore: code format * refactor: code clean-up * fix: refactoring stuff * refactor: remove un-used file * chore: code clean-up * chore: code clean-up * chore: code-format fix * refactor: remove nostr.client wrapper * refactor: code clean-up * chore: code format * refactor: remove `RelayList` class * refactor: extract smaller methods with try-catch * fix: better exception handling * fix: remove redundant filters * fix: simplify event * chore: code format * fix: code check * fix: code check * fix: simplify `REQ` * fix: more clean-ups * refactor: use simpler method * refactor: re-order and rename * fix: stop logic * fix: subscription close before disconnect * chore: play commit
142 lines
4.3 KiB
Python
142 lines
4.3 KiB
Python
import asyncio
|
|
import threading
|
|
import time
|
|
from typing import List
|
|
|
|
from loguru import logger
|
|
|
|
from .message_pool import MessagePool, NoticeMessage
|
|
from .relay import Relay
|
|
from .subscription import Subscription
|
|
|
|
|
|
class RelayManager:
    """Keep track of relays, their worker threads and the cached subscriptions.

    Relays are stored by URL. Every relay created through `add_relay` gets two
    daemon threads: one running `Relay.connect` and one running
    `Relay.queue_worker` inside its own `asyncio.run` call. Subscriptions are
    cached so that relays added later are sent the same subscriptions.
    """

    def __init__(self) -> None:
        # URL -> relay, URL -> connection thread, URL -> queue-worker thread.
        self.relays: dict[str, Relay] = {}
        self.threads: dict[str, threading.Thread] = {}
        self.queue_threads: dict[str, threading.Thread] = {}
        # Pool passed to every `Relay` this manager creates.
        self.message_pool = MessagePool()
        # Subscription id -> subscription; guarded by `_subscriptions_lock`.
        self._cached_subscriptions: dict[str, Subscription] = {}
        self._subscriptions_lock = threading.Lock()

    def add_relay(self, url: str) -> Relay:
        """Return the relay registered for `url`, creating it if necessary.

        A newly created relay is registered, its worker threads are started
        and every currently cached subscription is published to it. If the URL
        is already registered the existing relay is returned unchanged.
        """
        if url in self.relays:
            logger.debug(f"Relay '{url}' already present.")
            return self.relays[url]

        relay = Relay(url, self.message_pool)
        self.relays[url] = relay

        self._open_connection(relay)

        # Snapshot under the lock so a concurrent add/close of a subscription
        # cannot mutate the cache while it is being copied.
        with self._subscriptions_lock:
            cached = list(self._cached_subscriptions.values())
        relay.publish_subscriptions(cached)
        return relay

    def remove_relay(self, url: str):
        """Close the relay for `url` and drop it together with its threads.

        Best effort: errors raised while closing the relay or joining a thread
        are logged at debug level and never propagated. An unknown `url` is a
        no-op.
        """
        relay = self.relays.get(url)
        if relay is not None:
            try:
                relay.close()
            except Exception as e:
                logger.debug(e)
        self.relays.pop(url, None)

        self._join_and_forget(self.threads, url)
        self._join_and_forget(self.queue_threads, url)

    def remove_relays(self):
        """Remove every registered relay."""
        # Copy the keys first: `remove_relay` mutates `self.relays`.
        for url in list(self.relays):
            self.remove_relay(url)

    def add_subscription(self, id: str, filters: List[str]):
        """Cache a new subscription and publish it to all registered relays."""
        s = Subscription(id, filters)
        with self._subscriptions_lock:
            self._cached_subscriptions[id] = s

        for relay in self.relays.values():
            relay.publish_subscriptions([s])

    def close_subscription(self, id: str):
        """Forget subscription `id` and close it on all registered relays.

        Any exception is logged at debug level and swallowed.
        """
        try:
            with self._subscriptions_lock:
                self._cached_subscriptions.pop(id, None)

            for relay in self.relays.values():
                relay.close_subscription(id)
        except Exception as e:
            logger.debug(e)

    def close_subscriptions(self, subscriptions: List[str]):
        """Close each subscription id in `subscriptions`."""
        for id in subscriptions:
            self.close_subscription(id)

    def close_all_subscriptions(self):
        """Close every cached subscription."""
        # Snapshot the ids under the lock; `close_subscription` re-acquires
        # the (non-reentrant) lock itself, so it must be released first.
        with self._subscriptions_lock:
            all_subscriptions = list(self._cached_subscriptions)
        self.close_subscriptions(all_subscriptions)

    def check_and_restart_relays(self):
        """Attempt to restart every relay whose `shutdown` flag is set."""
        # Build the list up front: restarting mutates `self.relays`.
        stopped_relays = [r for r in self.relays.values() if r.shutdown]
        for relay in stopped_relays:
            self._restart_relay(relay)

    def close_connections(self):
        """Call `close()` on every registered relay without unregistering it."""
        for relay in self.relays.values():
            relay.close()

    def publish_message(self, message: str):
        """Publish `message` to every registered relay."""
        for relay in self.relays.values():
            relay.publish(message)

    def handle_notice(self, notice: NoticeMessage):
        """Attach a notice to the relay whose URL matches `notice.url`.

        A notice for a URL that is not registered is ignored.
        """
        # The `None` default matters: without it `next` raises StopIteration
        # when no relay matches, and the guard below is never reached.
        relay = next(
            (r for r in self.relays.values() if r.url == notice.url), None
        )
        if relay is not None:
            relay.add_notice(notice.content)

    def _open_connection(self, relay: Relay):
        """Start the connection thread and the queue-worker thread of `relay`."""
        self.threads[relay.url] = threading.Thread(
            target=relay.connect,
            name=f"{relay.url}-thread",
            daemon=True,
        )
        self.threads[relay.url].start()

        def wrap_async_queue_worker():
            # The queue worker is a coroutine; run it in an event loop owned
            # by this thread.
            asyncio.run(relay.queue_worker())

        self.queue_threads[relay.url] = threading.Thread(
            target=wrap_async_queue_worker,
            name=f"{relay.url}-queue",
            daemon=True,
        )
        self.queue_threads[relay.url].start()

    def _restart_relay(self, relay: Relay):
        """Replace `relay` with a fresh connection once its back-off has elapsed.

        The back-off is 60 seconds per recorded error, capped at one hour.
        The error counter and error list are carried over to the new relay.
        """
        time_since_last_error = time.time() - relay.last_error_date

        min_wait_time = min(
            60 * relay.error_counter, 60 * 60
        )  # try at least once an hour
        if time_since_last_error < min_wait_time:
            return

        logger.info(f"Restarting connection to relay '{relay.url}'")

        self.remove_relay(relay.url)
        new_relay = self.add_relay(relay.url)
        new_relay.error_counter = relay.error_counter
        new_relay.error_list = relay.error_list

    def _join_and_forget(
        self, registry: dict[str, threading.Thread], url: str
    ) -> None:
        """Join the thread stored under `url` (5 s timeout), then drop the entry."""
        thread = registry.get(url)
        if thread is not None:
            try:
                thread.join(timeout=5)
            except Exception as e:
                logger.debug(e)
        registry.pop(url, None)
|