feat: check access control list

This commit is contained in:
Vlad Stan 2023-02-08 11:16:34 +02:00
parent e1fe19b115
commit 43dc3e85ce
2 changed files with 37 additions and 10 deletions

View file

@ -1,5 +1,5 @@
import json
from typing import Any, Callable, List, Optional
from typing import Any, Awaitable, Callable, List, Optional
from fastapi import WebSocket
from loguru import logger
@ -12,7 +12,7 @@ from .crud import (
get_events,
mark_events_deleted,
)
from .models import NostrEvent, NostrEventType, NostrFilter, RelayConfig
from .models import ClientConfig, NostrEvent, NostrEventType, NostrFilter, RelayConfig
class NostrClientManager:
@ -28,7 +28,12 @@ class NostrClientManager:
allow_connect = await self._allow_client(client)
if not allow_connect:
return False
setattr(client, "broadcast_event", self.broadcast_event)
def get_client_config() -> ClientConfig:
return self.get_relay_config(client.relay_id)
setattr(client, "get_client_config", get_client_config)
self.clients(client.relay_id).append(client)
return True
@ -53,17 +58,19 @@ class NostrClientManager:
await self._stop_clients_for_relay(relay_id)
del self._active_relays[relay_id]
async def _stop_clients_for_relay(self, relay_id: str):
for client in self.clients(relay_id):
if client.relay_id == relay_id:
await client.stop(reason=f"Relay '{relay_id}' has been deactivated.")
def get_relay_config(self, relay_id: str) -> RelayConfig:
return self._active_relays[relay_id]
def clients(self, relay_id: str) -> List["NostrClientConnection"]:
if relay_id not in self._clients:
self._clients[relay_id] = []
return self._clients[relay_id]
async def _stop_clients_for_relay(self, relay_id: str):
for client in self.clients(relay_id):
if client.relay_id == relay_id:
await client.stop(reason=f"Relay '{relay_id}' has been deactivated.")
async def _allow_client(self, c: "NostrClientConnection") -> bool:
if c.relay_id not in self._active_relays:
await c.stop(reason=f"Relay '{c.relay_id}' is not active")
@ -77,7 +84,8 @@ class NostrClientConnection:
self.websocket = websocket
self.relay_id = relay_id
self.filters: List[NostrFilter] = []
self.broadcast_event: Optional[Callable] = None
self.broadcast_event: Optional[Callable[[NostrClientConnection, NostrEvent], Awaitable[None]]] = None
self.get_client_config: Optional[Callable[[], ClientConfig]] = None
async def start(self):
await self.websocket.accept()
@ -134,6 +142,10 @@ class NostrClientConnection:
resp_nip20: List[Any] = ["OK", e.id]
try:
e.check_signature()
if not self.client_config.is_author_allowed(e.pubkey):
raise ValueError(f"Public key '{e.pubkey}' is not allowed in relay '{self.relay_id}'!")
if e.is_replaceable_event():
await delete_events(
self.relay_id, NostrFilter(kinds=[e.kind], authors=[e.pubkey])
@ -144,7 +156,9 @@ class NostrClientConnection:
if e.is_delete_event():
await self._handle_delete_event(e)
resp_nip20 += [True, ""]
except ValueError:
except ValueError as ex:
#todo: handle the other Value Errors
logger.debug(ex)
resp_nip20 += [False, "invalid: wrong event `id` or `sig`"]
except Exception as ex:
logger.debug(ex)
@ -154,6 +168,12 @@ class NostrClientConnection:
await self.websocket.send_text(json.dumps(resp_nip20))
@property
def client_config(self) -> ClientConfig:
if not self.get_client_config:
raise Exception("Client not ready!")
return self.get_client_config()
async def _handle_delete_event(self, event: NostrEvent):
# NIP 09
filter = NostrFilter(authors=[event.pubkey])

View file

@ -8,12 +8,19 @@ from pydantic import BaseModel, Field
from secp256k1 import PublicKey
class ClientConfig(BaseModel):
max_client_filters = Field(0, alias="maxClientFilters")
allowed_public_keys = Field([], alias="allowedPublicKeys")
blocked_public_keys = Field([], alias="blockedPublicKeys")
def is_author_allowed(self, p: str) -> bool:
if p in self.blocked_public_keys:
return False
if len(self.allowed_public_keys) == 0:
return True
# todo: check payment
return p in self.allowed_public_keys
class Config:
allow_population_by_field_name = True
class RelayConfig(ClientConfig):