feat: block access to deactivated client

This commit is contained in:
Vlad Stan 2023-02-06 17:42:27 +02:00
parent dedcf823bd
commit f56e9e2e56
3 changed files with 36 additions and 6 deletions

View file

@ -1,5 +1,5 @@
import json import json
from typing import Any, Callable, List from typing import Any, Callable, List, Optional
from fastapi import WebSocket from fastapi import WebSocket
from loguru import logger from loguru import logger
@ -7,6 +7,7 @@ from loguru import logger
from .crud import ( from .crud import (
create_event, create_event,
delete_events, delete_events,
get_all_active_relays_ids,
get_event, get_event,
get_events, get_events,
mark_events_deleted, mark_events_deleted,
@ -15,13 +16,19 @@ from .models import NostrEvent, NostrEventType, NostrFilter
class NostrClientManager: class NostrClientManager:
def __init__(self): def __init__(self: "NostrClientManager"):
self.clients: List["NostrClientConnection"] = [] self.clients: List["NostrClientConnection"] = []
self.active_relays: Optional[List[str]] = None
def add_client(self, client: "NostrClientConnection"): async def add_client(self, client: "NostrClientConnection") -> bool:
allow_connect = await self.allow_client_to_connect(client.relay_id, client.websocket)
if not allow_connect:
return False
setattr(client, "broadcast_event", self.broadcast_event) setattr(client, "broadcast_event", self.broadcast_event)
self.clients.append(client) self.clients.append(client)
return True
def remove_client(self, client: "NostrClientConnection"): def remove_client(self, client: "NostrClientConnection"):
self.clients.remove(client) self.clients.remove(client)
@ -30,6 +37,23 @@ class NostrClientManager:
if client != source: if client != source:
await client.notify_event(event) await client.notify_event(event)
async def allow_client_to_connect(self, relay_id:str, websocket: WebSocket) -> bool:
if not self.active_relays:
self.active_relays = await get_all_active_relays_ids()
if relay_id not in self.active_relays:
await websocket.close(reason=f"Relay '{relay_id}' is not active")
return False
return True
async def toggle_relay(self, relay_id: str, active: bool):
if not self.active_relays:
self.active_relays = await get_all_active_relays_ids()
if active:
self.active_relays.append(relay_id)
else:
self.active_relays = [r for r in self.active_relays if r != relay_id]
class NostrClientConnection: class NostrClientConnection:
broadcast_event: Callable broadcast_event: Callable

View file

@ -42,6 +42,9 @@ async def get_relays(user_id: str) -> List[NostrRelay]:
return [NostrRelay.from_row(row) for row in rows] return [NostrRelay.from_row(row) for row in rows]
async def get_all_active_relays_ids() -> List[str]:
rows = await db.fetchall("SELECT id FROM nostrrelay.relays WHERE active = true",)
return [r["id"] for r in rows]
async def get_public_relay(relay_id: str) -> Optional[dict]: async def get_public_relay(relay_id: str) -> Optional[dict]:
row = await db.fetchone("""SELECT * FROM nostrrelay.relays WHERE id = ?""", (relay_id,)) row = await db.fetchone("""SELECT * FROM nostrrelay.relays WHERE id = ?""", (relay_id,))

View file

@ -1,7 +1,7 @@
from http import HTTPStatus from http import HTTPStatus
from typing import List, Optional from typing import List, Optional
from fastapi import Depends, Query, WebSocket from fastapi import Depends, WebSocket
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from loguru import logger from loguru import logger
@ -28,12 +28,13 @@ from .crud import (
from .models import NostrRelay from .models import NostrRelay
client_manager = NostrClientManager() client_manager = NostrClientManager()
active_relays: List[str] = []
@nostrrelay_ext.websocket("/{relay_id}") @nostrrelay_ext.websocket("/{relay_id}")
async def websocket_endpoint(relay_id: str, websocket: WebSocket): async def websocket_endpoint(relay_id: str, websocket: WebSocket):
client = NostrClientConnection(relay_id=relay_id, websocket=websocket) client = NostrClientConnection(relay_id=relay_id, websocket=websocket)
client_manager.add_client(client) if not (await client_manager.add_client(client)):
return
try: try:
await client.start() await client.start()
except Exception as e: except Exception as e:
@ -41,6 +42,7 @@ async def websocket_endpoint(relay_id: str, websocket: WebSocket):
client_manager.remove_client(client) client_manager.remove_client(client)
@nostrrelay_ext.get("/{relay_id}", status_code=HTTPStatus.OK) @nostrrelay_ext.get("/{relay_id}", status_code=HTTPStatus.OK)
async def api_nostrrelay_info(relay_id: str): async def api_nostrrelay_info(relay_id: str):
relay = await get_public_relay(relay_id) relay = await get_public_relay(relay_id)
@ -93,6 +95,7 @@ async def api_update_relay(relay_id: str, data: NostrRelay, wallet: WalletTypeIn
) )
updated_relay = NostrRelay.parse_obj({**dict(relay), **dict(data)}) updated_relay = NostrRelay.parse_obj({**dict(relay), **dict(data)})
updated_relay = await update_relay(wallet.wallet.user, updated_relay) updated_relay = await update_relay(wallet.wallet.user, updated_relay)
await client_manager.toggle_relay(relay_id, updated_relay.active)
return updated_relay return updated_relay
except HTTPException as ex: except HTTPException as ex: