From f97cd1dff68615df453ca52f2164b7646fc56413 Mon Sep 17 00:00:00 2001 From: Vlad Stan Date: Wed, 8 Feb 2023 08:59:19 +0200 Subject: [PATCH] fix: group clients by relay --- client_manager.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/client_manager.py b/client_manager.py index 6e8f548..634e3c1 100644 --- a/client_manager.py +++ b/client_manager.py @@ -17,7 +17,7 @@ from .models import NostrEvent, NostrEventType, NostrFilter class NostrClientManager: def __init__(self: "NostrClientManager"): - self.clients: List["NostrClientConnection"] = [] + self.clients: dict = {} self.active_relays: Optional[List[str]] = None async def add_client(self, client: "NostrClientConnection") -> bool: @@ -25,15 +25,15 @@ class NostrClientManager: if not allow_connect: return False setattr(client, "broadcast_event", self.broadcast_event) - self.clients.append(client) + self.relay_clients(client.relay_id).append(client) return True def remove_client(self, client: "NostrClientConnection"): - self.clients.remove(client) + self.relay_clients(client.relay_id).remove(client) async def broadcast_event(self, source: "NostrClientConnection", event: NostrEvent): - for client in self.clients: + for client in self.relay_clients(source.relay_id): if client != source: await client.notify_event(event) @@ -56,11 +56,15 @@ class NostrClientManager: await self.stop_clients_for_relay(relay_id) async def stop_clients_for_relay(self, relay_id: str): - for client in self.clients: + for client in self.relay_clients(relay_id): if client.relay_id == relay_id: await client.stop(reason=f"Relay '{relay_id}' has been deactivated.") - + def relay_clients(self, relay_id: str) -> List["NostrClientConnection"]: + if relay_id not in self.clients: + self.clients[relay_id] = [] + return self.clients[relay_id] + class NostrClientConnection: def __init__(self, relay_id: str, websocket: WebSocket):