diff --git a/client_manager.py b/client_manager.py index 114b774..51791bf 100644 --- a/client_manager.py +++ b/client_manager.py @@ -7,20 +7,24 @@ from loguru import logger from .crud import ( create_event, delete_events, - get_all_active_relays_ids, get_event, get_events, + get_config_for_all_active_relays, mark_events_deleted, ) -from .models import NostrEvent, NostrEventType, NostrFilter +from .models import NostrEvent, NostrEventType, NostrFilter, RelayConfig class NostrClientManager: def __init__(self: "NostrClientManager"): self.clients: dict = {} - self.active_relays: Optional[List[str]] = None + self.active_relays: dict = {} + self.is_ready = False async def add_client(self, client: "NostrClientConnection") -> bool: + if not self.is_ready: + return False + allow_connect = await self._allow_client_to_connect(client.relay_id, client.websocket) if not allow_connect: return False @@ -29,22 +33,26 @@ class NostrClientManager: return True - def remove_client(self, client: "NostrClientConnection"): - self._clients(client.relay_id).remove(client) + def remove_client(self, c: "NostrClientConnection"): + self._clients(c.relay_id).remove(c) async def broadcast_event(self, source: "NostrClientConnection", event: NostrEvent): for client in self._clients(source.relay_id): if client != source: await client.notify_event(event) - async def toggle_relay(self, relay_id: str, active: bool): - if not self.active_relays: - self.active_relays = await get_all_active_relays_ids() - if active: - self.active_relays.append(relay_id) - else: - self.active_relays = [r for r in self.active_relays if r != relay_id] - await self._stop_clients_for_relay(relay_id) + async def init_relays(self): + self.active_relays = await get_config_for_all_active_relays() + self.is_ready = True + + async def enable_relay(self, relay_id: str, config: RelayConfig): + self.is_ready = True + self.active_relays[relay_id] = config + + async def disable_relay(self, relay_id: str): + await self._stop_clients_for_relay(relay_id) + del self.active_relays[relay_id] + async def _stop_clients_for_relay(self, relay_id: str): for client in self._clients(relay_id): @@ -57,9 +65,6 @@ class NostrClientManager: return self.clients[relay_id] async def _allow_client_to_connect(self, relay_id:str, websocket: WebSocket) -> bool: - if not self.active_relays: - self.active_relays = await get_all_active_relays_ids() - if relay_id not in self.active_relays: await websocket.close(reason=f"Relay '{relay_id}' is not active") return False diff --git a/crud.py b/crud.py index 8e81cde..9d36059 100644 --- a/crud.py +++ b/crud.py @@ -4,7 +4,7 @@ from typing import Any, List, Optional from lnbits.helpers import urlsafe_short_hash from . import db -from .models import NostrEvent, NostrFilter, NostrRelay +from .models import NostrEvent, NostrFilter, NostrRelay, RelayConfig ########################## RELAYS #################### @@ -42,9 +42,13 @@ async def get_relays(user_id: str) -> List[NostrRelay]: return [NostrRelay.from_row(row) for row in rows] -async def get_all_active_relays_ids() -> List[str]: - rows = await db.fetchall("SELECT id FROM nostrrelay.relays WHERE active = true",) - return [r["id"] for r in rows] +async def get_config_for_all_active_relays() -> dict: + rows = await db.fetchall("SELECT id, meta FROM nostrrelay.relays WHERE active = true",) + active_relay_configs = {} + for r in rows: + active_relay_configs[r["id"]] = RelayConfig(**json.loads(r["meta"])) #todo: from_json + + return active_relay_configs async def get_public_relay(relay_id: str) -> Optional[dict]: row = await db.fetchone("""SELECT * FROM nostrrelay.relays WHERE id = ?""", (relay_id,)) diff --git a/tests/test_clients.py b/tests/test_clients.py index 205e69d..73469eb 100644 --- a/tests/test_clients.py +++ b/tests/test_clients.py @@ -10,6 +10,7 @@ from lnbits.extensions.nostrrelay.client_manager import ( NostrClientConnection, NostrClientManager, ) +from lnbits.extensions.nostrrelay.models import RelayConfig from .helpers import get_fixtures @@ -72,8 +73,7 @@ async def test_alice_and_bob(): async def init_clients(): client_manager = NostrClientManager() - client_manager.active_relays = [RELAY_ID] - client_manager.toggle_relay(RELAY_ID, True) + await client_manager.enable_relay(RELAY_ID, RelayConfig()) ws_alice = MockWebSocket() client_alice = NostrClientConnection(relay_id=RELAY_ID, websocket=ws_alice) diff --git a/views_api.py b/views_api.py index 8c9c97f..f45c11a 100644 --- a/views_api.py +++ b/views_api.py @@ -95,7 +95,12 @@ async def api_update_relay(relay_id: str, data: NostrRelay, wallet: WalletTypeIn ) updated_relay = NostrRelay.parse_obj({**dict(relay), **dict(data)}) updated_relay = await update_relay(wallet.wallet.user, updated_relay) - await client_manager.toggle_relay(relay_id, updated_relay.active) + + if updated_relay.active: + await client_manager.enable_relay(relay_id, updated_relay.config) + else: + await client_manager.disable_relay(relay_id) + return updated_relay except HTTPException as ex: @@ -139,7 +144,7 @@ async def api_get_relay(relay_id: str, wallet: WalletTypeInfo = Depends(require_ @nostrrelay_ext.delete("/api/v1/relay/{relay_id}") async def api_delete_relay(relay_id: str, wallet: WalletTypeInfo = Depends(require_admin_key)): try: - await client_manager.toggle_relay(relay_id, False) + await client_manager.disable_relay(relay_id) await delete_relay(wallet.wallet.user, relay_id) await delete_all_events(relay_id) except Exception as ex: