From f56e9e2e56566a43a6de920a712244b1bc5969aa Mon Sep 17 00:00:00 2001 From: Vlad Stan Date: Mon, 6 Feb 2023 17:42:27 +0200 Subject: [PATCH] feat: block access to deactivated client --- client_manager.py | 30 +++++++++++++++++++++++++++--- crud.py | 3 +++ views_api.py | 9 ++++++--- 3 files changed, 36 insertions(+), 6 deletions(-) diff --git a/client_manager.py b/client_manager.py index 4d85185..b4d5b73 100644 --- a/client_manager.py +++ b/client_manager.py @@ -1,5 +1,5 @@ import json -from typing import Any, Callable, List +from typing import Any, Callable, List, Optional from fastapi import WebSocket from loguru import logger @@ -7,6 +7,7 @@ from loguru import logger from .crud import ( create_event, delete_events, + get_all_active_relays_ids, get_event, get_events, mark_events_deleted, @@ -15,13 +16,19 @@ from .models import NostrEvent, NostrEventType, NostrFilter class NostrClientManager: - def __init__(self): + def __init__(self: "NostrClientManager"): self.clients: List["NostrClientConnection"] = [] + self.active_relays: Optional[List[str]] = None - def add_client(self, client: "NostrClientConnection"): + async def add_client(self, client: "NostrClientConnection") -> bool: + allow_connect = await self.allow_client_to_connect(client.relay_id, client.websocket) + if not allow_connect: + return False setattr(client, "broadcast_event", self.broadcast_event) self.clients.append(client) + return True + def remove_client(self, client: "NostrClientConnection"): self.clients.remove(client) @@ -30,6 +37,23 @@ class NostrClientManager: if client != source: await client.notify_event(event) + async def allow_client_to_connect(self, relay_id:str, websocket: WebSocket) -> bool: + if not self.active_relays: + self.active_relays = await get_all_active_relays_ids() + + if relay_id not in self.active_relays: + await websocket.close(reason=f"Relay '{relay_id}' is not active") + return False + return True + + async def toggle_relay(self, relay_id: str, active: bool): + if not self.active_relays: + self.active_relays = await get_all_active_relays_ids() + if active: + self.active_relays.append(relay_id) + else: + self.active_relays = [r for r in self.active_relays if r != relay_id] + class NostrClientConnection: broadcast_event: Callable diff --git a/crud.py b/crud.py index 42d4db1..762d910 100644 --- a/crud.py +++ b/crud.py @@ -42,6 +42,9 @@ async def get_relays(user_id: str) -> List[NostrRelay]: return [NostrRelay.from_row(row) for row in rows] +async def get_all_active_relays_ids() -> List[str]: + rows = await db.fetchall("SELECT id FROM nostrrelay.relays WHERE active = true",) + return [r["id"] for r in rows] async def get_public_relay(relay_id: str) -> Optional[dict]: row = await db.fetchone("""SELECT * FROM nostrrelay.relays WHERE id = ?""", (relay_id,)) diff --git a/views_api.py b/views_api.py index ea1369c..93dddca 100644 --- a/views_api.py +++ b/views_api.py @@ -1,7 +1,7 @@ from http import HTTPStatus from typing import List, Optional -from fastapi import Depends, Query, WebSocket +from fastapi import Depends, WebSocket from fastapi.exceptions import HTTPException from fastapi.responses import JSONResponse from loguru import logger @@ -28,12 +28,13 @@ from .crud import ( from .models import NostrRelay client_manager = NostrClientManager() -active_relays: List[str] = [] @nostrrelay_ext.websocket("/{relay_id}") async def websocket_endpoint(relay_id: str, websocket: WebSocket): client = NostrClientConnection(relay_id=relay_id, websocket=websocket) - client_manager.add_client(client) + if not (await client_manager.add_client(client)): + return + try: await client.start() except Exception as e: @@ -41,6 +42,7 @@ async def websocket_endpoint(relay_id: str, websocket: WebSocket): client_manager.remove_client(client) + @nostrrelay_ext.get("/{relay_id}", status_code=HTTPStatus.OK) async def api_nostrrelay_info(relay_id: str): relay = await get_public_relay(relay_id) @@ -93,6 +95,7 @@ async def api_update_relay(relay_id: str, data: NostrRelay, wallet: WalletTypeIn ) updated_relay = NostrRelay.parse_obj({**dict(relay), **dict(data)}) updated_relay = await update_relay(wallet.wallet.user, updated_relay) + await client_manager.toggle_relay(relay_id, updated_relay.active) return updated_relay except HTTPException as ex: