fix: group clients by relay

This commit is contained in:
Vlad Stan 2023-02-08 08:59:19 +02:00
parent 4e5c2657c9
commit f97cd1dff6

View file

@ -17,7 +17,7 @@ from .models import NostrEvent, NostrEventType, NostrFilter
class NostrClientManager: class NostrClientManager:
def __init__(self: "NostrClientManager"): def __init__(self: "NostrClientManager"):
self.clients: List["NostrClientConnection"] = [] self.clients: dict = {}
self.active_relays: Optional[List[str]] = None self.active_relays: Optional[List[str]] = None
async def add_client(self, client: "NostrClientConnection") -> bool: async def add_client(self, client: "NostrClientConnection") -> bool:
@ -25,15 +25,15 @@ class NostrClientManager:
if not allow_connect: if not allow_connect:
return False return False
setattr(client, "broadcast_event", self.broadcast_event) setattr(client, "broadcast_event", self.broadcast_event)
self.clients.append(client) self.relay_clients(client.relay_id).append(client)
return True return True
def remove_client(self, client: "NostrClientConnection"): def remove_client(self, client: "NostrClientConnection"):
self.clients.remove(client) self.relay_clients(client.relay_id).remove(client)
async def broadcast_event(self, source: "NostrClientConnection", event: NostrEvent): async def broadcast_event(self, source: "NostrClientConnection", event: NostrEvent):
for client in self.clients: for client in self.relay_clients(source.relay_id):
if client != source: if client != source:
await client.notify_event(event) await client.notify_event(event)
@ -56,11 +56,15 @@ class NostrClientManager:
await self.stop_clients_for_relay(relay_id) await self.stop_clients_for_relay(relay_id)
async def stop_clients_for_relay(self, relay_id: str): async def stop_clients_for_relay(self, relay_id: str):
for client in self.clients: for client in self.relay_clients(relay_id):
if client.relay_id == relay_id: if client.relay_id == relay_id:
await client.stop(reason=f"Relay '{relay_id}' has been deactivated.") await client.stop(reason=f"Relay '{relay_id}' has been deactivated.")
def relay_clients(self, relay_id: str) -> List["NostrClientConnection"]:
if relay_id not in self.clients:
self.clients[relay_id] = []
return self.clients[relay_id]
class NostrClientConnection: class NostrClientConnection:
def __init__(self, relay_id: str, websocket: WebSocket): def __init__(self, relay_id: str, websocket: WebSocket):