feat: block access to deactivated client
This commit is contained in:
parent
dedcf823bd
commit
f56e9e2e56
3 changed files with 36 additions and 6 deletions
|
|
@ -1,5 +1,5 @@
|
||||||
import json
|
import json
|
||||||
from typing import Any, Callable, List
|
from typing import Any, Callable, List, Optional
|
||||||
|
|
||||||
from fastapi import WebSocket
|
from fastapi import WebSocket
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
@ -7,6 +7,7 @@ from loguru import logger
|
||||||
from .crud import (
|
from .crud import (
|
||||||
create_event,
|
create_event,
|
||||||
delete_events,
|
delete_events,
|
||||||
|
get_all_active_relays_ids,
|
||||||
get_event,
|
get_event,
|
||||||
get_events,
|
get_events,
|
||||||
mark_events_deleted,
|
mark_events_deleted,
|
||||||
|
|
@ -15,13 +16,19 @@ from .models import NostrEvent, NostrEventType, NostrFilter
|
||||||
|
|
||||||
|
|
||||||
class NostrClientManager:
|
class NostrClientManager:
|
||||||
def __init__(self):
|
def __init__(self: "NostrClientManager"):
|
||||||
self.clients: List["NostrClientConnection"] = []
|
self.clients: List["NostrClientConnection"] = []
|
||||||
|
self.active_relays: Optional[List[str]] = None
|
||||||
|
|
||||||
def add_client(self, client: "NostrClientConnection"):
|
async def add_client(self, client: "NostrClientConnection") -> bool:
|
||||||
|
allow_connect = await self.allow_client_to_connect(client.relay_id, client.websocket)
|
||||||
|
if not allow_connect:
|
||||||
|
return False
|
||||||
setattr(client, "broadcast_event", self.broadcast_event)
|
setattr(client, "broadcast_event", self.broadcast_event)
|
||||||
self.clients.append(client)
|
self.clients.append(client)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
def remove_client(self, client: "NostrClientConnection"):
|
def remove_client(self, client: "NostrClientConnection"):
|
||||||
self.clients.remove(client)
|
self.clients.remove(client)
|
||||||
|
|
||||||
|
|
@ -30,6 +37,23 @@ class NostrClientManager:
|
||||||
if client != source:
|
if client != source:
|
||||||
await client.notify_event(event)
|
await client.notify_event(event)
|
||||||
|
|
||||||
|
async def allow_client_to_connect(self, relay_id:str, websocket: WebSocket) -> bool:
|
||||||
|
if not self.active_relays:
|
||||||
|
self.active_relays = await get_all_active_relays_ids()
|
||||||
|
|
||||||
|
if relay_id not in self.active_relays:
|
||||||
|
await websocket.close(reason=f"Relay '{relay_id}' is not active")
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def toggle_relay(self, relay_id: str, active: bool):
|
||||||
|
if not self.active_relays:
|
||||||
|
self.active_relays = await get_all_active_relays_ids()
|
||||||
|
if active:
|
||||||
|
self.active_relays.append(relay_id)
|
||||||
|
else:
|
||||||
|
self.active_relays = [r for r in self.active_relays if r != relay_id]
|
||||||
|
|
||||||
|
|
||||||
class NostrClientConnection:
|
class NostrClientConnection:
|
||||||
broadcast_event: Callable
|
broadcast_event: Callable
|
||||||
|
|
|
||||||
3
crud.py
3
crud.py
|
|
@ -42,6 +42,9 @@ async def get_relays(user_id: str) -> List[NostrRelay]:
|
||||||
|
|
||||||
return [NostrRelay.from_row(row) for row in rows]
|
return [NostrRelay.from_row(row) for row in rows]
|
||||||
|
|
||||||
|
async def get_all_active_relays_ids() -> List[str]:
|
||||||
|
rows = await db.fetchall("SELECT id FROM nostrrelay.relays WHERE active = true",)
|
||||||
|
return [r["id"] for r in rows]
|
||||||
|
|
||||||
async def get_public_relay(relay_id: str) -> Optional[dict]:
|
async def get_public_relay(relay_id: str) -> Optional[dict]:
|
||||||
row = await db.fetchone("""SELECT * FROM nostrrelay.relays WHERE id = ?""", (relay_id,))
|
row = await db.fetchone("""SELECT * FROM nostrrelay.relays WHERE id = ?""", (relay_id,))
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from fastapi import Depends, Query, WebSocket
|
from fastapi import Depends, WebSocket
|
||||||
from fastapi.exceptions import HTTPException
|
from fastapi.exceptions import HTTPException
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
@ -28,12 +28,13 @@ from .crud import (
|
||||||
from .models import NostrRelay
|
from .models import NostrRelay
|
||||||
|
|
||||||
client_manager = NostrClientManager()
|
client_manager = NostrClientManager()
|
||||||
active_relays: List[str] = []
|
|
||||||
|
|
||||||
@nostrrelay_ext.websocket("/{relay_id}")
|
@nostrrelay_ext.websocket("/{relay_id}")
|
||||||
async def websocket_endpoint(relay_id: str, websocket: WebSocket):
|
async def websocket_endpoint(relay_id: str, websocket: WebSocket):
|
||||||
client = NostrClientConnection(relay_id=relay_id, websocket=websocket)
|
client = NostrClientConnection(relay_id=relay_id, websocket=websocket)
|
||||||
client_manager.add_client(client)
|
if not (await client_manager.add_client(client)):
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await client.start()
|
await client.start()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -41,6 +42,7 @@ async def websocket_endpoint(relay_id: str, websocket: WebSocket):
|
||||||
client_manager.remove_client(client)
|
client_manager.remove_client(client)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@nostrrelay_ext.get("/{relay_id}", status_code=HTTPStatus.OK)
|
@nostrrelay_ext.get("/{relay_id}", status_code=HTTPStatus.OK)
|
||||||
async def api_nostrrelay_info(relay_id: str):
|
async def api_nostrrelay_info(relay_id: str):
|
||||||
relay = await get_public_relay(relay_id)
|
relay = await get_public_relay(relay_id)
|
||||||
|
|
@ -93,6 +95,7 @@ async def api_update_relay(relay_id: str, data: NostrRelay, wallet: WalletTypeIn
|
||||||
)
|
)
|
||||||
updated_relay = NostrRelay.parse_obj({**dict(relay), **dict(data)})
|
updated_relay = NostrRelay.parse_obj({**dict(relay), **dict(data)})
|
||||||
updated_relay = await update_relay(wallet.wallet.user, updated_relay)
|
updated_relay = await update_relay(wallet.wallet.user, updated_relay)
|
||||||
|
await client_manager.toggle_relay(relay_id, updated_relay.active)
|
||||||
return updated_relay
|
return updated_relay
|
||||||
|
|
||||||
except HTTPException as ex:
|
except HTTPException as ex:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue