From 811bfdc45a707192d0c43ac91738ae76344cbc3f Mon Sep 17 00:00:00 2001 From: Vlad Stan Date: Thu, 22 Jun 2023 16:55:09 +0300 Subject: [PATCH] fix: init new relays with previous subscriptions --- nostr/client/client.py | 4 ++-- nostr/relay_manager.py | 16 ++++++++-------- views_api.py | 6 ++---- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/nostr/client/client.py b/nostr/client/client.py index 0ff85a8..7c6063e 100644 --- a/nostr/client/client.py +++ b/nostr/client/client.py @@ -20,9 +20,9 @@ class NostrClient: if connect: self.connect() - async def connect(self, subscriptions: dict[str, Subscription] = {}): + async def connect(self): for relay in self.relays: - self.relay_manager.add_relay(relay, subscriptions) + self.relay_manager.add_relay(relay) diff --git a/nostr/relay_manager.py b/nostr/relay_manager.py index 63eb68f..ddc833c 100644 --- a/nostr/relay_manager.py +++ b/nostr/relay_manager.py @@ -1,5 +1,4 @@ -from asyncio import Lock import ssl import threading @@ -19,17 +18,18 @@ class RelayManager: self.threads: dict[str, threading.Thread] = {} self.queue_threads: dict[str, threading.Thread] = {} self.message_pool = MessagePool() - self._cached_subscriptions = dict[str, Subscription] = {} - self._subscriptions_lock = Lock() + self._cached_subscriptions: dict[str, Subscription] = {} + self._subscriptions_lock = threading.Lock() - def add_relay( - self, url: str, read: bool = True, write: bool = True, subscriptions: dict[str, Subscription] = {} - ) -> Relay: + def add_relay(self, url: str, read: bool = True, write: bool = True) -> Relay: if url in self.relays: return - + + with self._subscriptions_lock: + subscriptions = self._cached_subscriptions.copy() + policy = RelayPolicy(read, write) - relay = Relay(url, policy, self.message_pool, subscriptions.copy()) + relay = Relay(url, policy, self.message_pool, subscriptions) self.relays[url] = relay self.open_connection( diff --git a/views_api.py b/views_api.py index 0036070..131da44 100644 --- a/views_api.py +++ b/views_api.py @@ -60,10 +60,8 @@ async def api_add_relay(relay: Relay) -> Optional[RelayList]: relay.id = urlsafe_short_hash() await add_relay(relay) - all_relays: List[NostrRelay] = nostr.client.relay_manager.relays.values() - if len(all_relays): - nostr.client.relays.append(relay.url) - nostr.client.relay_manager.add_relay() + nostr.client.relays.append(relay.url) + nostr.client.relay_manager.add_relay(relay.url) return await get_relays()