fix: init new relays with previous subscriptions

This commit is contained in:
Vlad Stan 2023-06-22 16:55:09 +03:00
parent af14e1c47b
commit 811bfdc45a
3 changed files with 12 additions and 14 deletions

View file

@ -20,9 +20,9 @@ class NostrClient:
if connect: if connect:
self.connect() self.connect()
async def connect(self, subscriptions: dict[str, Subscription] = {}): async def connect(self):
for relay in self.relays: for relay in self.relays:
self.relay_manager.add_relay(relay, subscriptions) self.relay_manager.add_relay(relay)

View file

@ -1,5 +1,4 @@
from asyncio import Lock
import ssl import ssl
import threading import threading
@ -19,17 +18,18 @@ class RelayManager:
self.threads: dict[str, threading.Thread] = {} self.threads: dict[str, threading.Thread] = {}
self.queue_threads: dict[str, threading.Thread] = {} self.queue_threads: dict[str, threading.Thread] = {}
self.message_pool = MessagePool() self.message_pool = MessagePool()
self._cached_subscriptions = dict[str, Subscription] = {} self._cached_subscriptions: dict[str, Subscription] = {}
self._subscriptions_lock = Lock() self._subscriptions_lock = threading.Lock()
def add_relay( def add_relay(self, url: str, read: bool = True, write: bool = True) -> Relay:
self, url: str, read: bool = True, write: bool = True, subscriptions: dict[str, Subscription] = {}
) -> Relay:
if url in self.relays: if url in self.relays:
return return
with self._subscriptions_lock:
subscriptions = self._cached_subscriptions.copy()
policy = RelayPolicy(read, write) policy = RelayPolicy(read, write)
relay = Relay(url, policy, self.message_pool, subscriptions.copy()) relay = Relay(url, policy, self.message_pool, subscriptions)
self.relays[url] = relay self.relays[url] = relay
self.open_connection( self.open_connection(

View file

@ -60,10 +60,8 @@ async def api_add_relay(relay: Relay) -> Optional[RelayList]:
relay.id = urlsafe_short_hash() relay.id = urlsafe_short_hash()
await add_relay(relay) await add_relay(relay)
all_relays: List[NostrRelay] = nostr.client.relay_manager.relays.values() nostr.client.relays.append(relay.url)
if len(all_relays): nostr.client.relay_manager.add_relay(relay.url)
nostr.client.relays.append(relay.url)
nostr.client.relay_manager.add_relay()
return await get_relays() return await get_relays()