fix: do not lose subscriptions if no relay

This commit is contained in:
Vlad Stan 2023-06-22 15:15:00 +03:00
parent c0632cabe5
commit af14e1c47b
2 changed files with 10 additions and 3 deletions

View file

@ -1,4 +1,5 @@
from asyncio import Lock
import ssl import ssl
import threading import threading
@ -18,6 +19,8 @@ class RelayManager:
self.threads: dict[str, threading.Thread] = {} self.threads: dict[str, threading.Thread] = {}
self.queue_threads: dict[str, threading.Thread] = {} self.queue_threads: dict[str, threading.Thread] = {}
self.message_pool = MessagePool() self.message_pool = MessagePool()
self._cached_subscriptions = dict[str, Subscription] = {}
self._subscriptions_lock = Lock()
def add_relay( def add_relay(
self, url: str, read: bool = True, write: bool = True, subscriptions: dict[str, Subscription] = {} self, url: str, read: bool = True, write: bool = True, subscriptions: dict[str, Subscription] = {}
@ -46,10 +49,16 @@ class RelayManager:
self.relays.pop(url) self.relays.pop(url)
def add_subscription(self, id: str, filters: Filters): def add_subscription(self, id: str, filters: Filters):
with self._subscriptions_lock:
self._cached_subscriptions[id] = Subscription(id, filters)
for relay in self.relays.values(): for relay in self.relays.values():
relay.add_subscription(id, filters) relay.add_subscription(id, filters)
def close_subscription(self, id: str): def close_subscription(self, id: str):
with self._subscriptions_lock:
self._cached_subscriptions.pop(id)
for relay in self.relays.values(): for relay in self.relays.values():
relay.close_subscription(id) relay.close_subscription(id)

View file

@ -62,11 +62,9 @@ async def api_add_relay(relay: Relay) -> Optional[RelayList]:
all_relays: List[NostrRelay] = nostr.client.relay_manager.relays.values() all_relays: List[NostrRelay] = nostr.client.relay_manager.relays.values()
if len(all_relays): if len(all_relays):
subscriptions = all_relays[0].subscriptions
nostr.client.relays.append(relay.url) nostr.client.relays.append(relay.url)
nostr.client.relay_manager.add_relay(subscriptions) nostr.client.relay_manager.add_relay()
nostr.client.relay_manager.connect_relay(relay.url)
return await get_relays() return await get_relays()