feat: check access control list

This commit is contained in:
Vlad Stan 2023-02-08 11:16:34 +02:00
parent e1fe19b115
commit 43dc3e85ce
2 changed files with 37 additions and 10 deletions

View file

@ -1,5 +1,5 @@
import json import json
from typing import Any, Callable, List, Optional from typing import Any, Awaitable, Callable, List, Optional
from fastapi import WebSocket from fastapi import WebSocket
from loguru import logger from loguru import logger
@ -12,7 +12,7 @@ from .crud import (
get_events, get_events,
mark_events_deleted, mark_events_deleted,
) )
from .models import NostrEvent, NostrEventType, NostrFilter, RelayConfig from .models import ClientConfig, NostrEvent, NostrEventType, NostrFilter, RelayConfig
class NostrClientManager: class NostrClientManager:
@ -28,7 +28,12 @@ class NostrClientManager:
allow_connect = await self._allow_client(client) allow_connect = await self._allow_client(client)
if not allow_connect: if not allow_connect:
return False return False
setattr(client, "broadcast_event", self.broadcast_event) setattr(client, "broadcast_event", self.broadcast_event)
def get_client_config() -> ClientConfig:
return self.get_relay_config(client.relay_id)
setattr(client, "get_client_config", get_client_config)
self.clients(client.relay_id).append(client) self.clients(client.relay_id).append(client)
return True return True
@ -52,18 +57,20 @@ class NostrClientManager:
async def disable_relay(self, relay_id: str): async def disable_relay(self, relay_id: str):
await self._stop_clients_for_relay(relay_id) await self._stop_clients_for_relay(relay_id)
del self._active_relays[relay_id] del self._active_relays[relay_id]
def get_relay_config(self, relay_id: str) -> RelayConfig:
return self._active_relays[relay_id]
def clients(self, relay_id: str) -> List["NostrClientConnection"]:
if relay_id not in self._clients:
self._clients[relay_id] = []
return self._clients[relay_id]
async def _stop_clients_for_relay(self, relay_id: str): async def _stop_clients_for_relay(self, relay_id: str):
for client in self.clients(relay_id): for client in self.clients(relay_id):
if client.relay_id == relay_id: if client.relay_id == relay_id:
await client.stop(reason=f"Relay '{relay_id}' has been deactivated.") await client.stop(reason=f"Relay '{relay_id}' has been deactivated.")
def clients(self, relay_id: str) -> List["NostrClientConnection"]:
if relay_id not in self._clients:
self._clients[relay_id] = []
return self._clients[relay_id]
async def _allow_client(self, c: "NostrClientConnection") -> bool: async def _allow_client(self, c: "NostrClientConnection") -> bool:
if c.relay_id not in self._active_relays: if c.relay_id not in self._active_relays:
await c.stop(reason=f"Relay '{c.relay_id}' is not active") await c.stop(reason=f"Relay '{c.relay_id}' is not active")
@ -77,7 +84,8 @@ class NostrClientConnection:
self.websocket = websocket self.websocket = websocket
self.relay_id = relay_id self.relay_id = relay_id
self.filters: List[NostrFilter] = [] self.filters: List[NostrFilter] = []
self.broadcast_event: Optional[Callable] = None self.broadcast_event: Optional[Callable[[NostrClientConnection, NostrEvent], Awaitable[None]]] = None
self.get_client_config: Optional[Callable[[], ClientConfig]] = None
async def start(self): async def start(self):
await self.websocket.accept() await self.websocket.accept()
@ -134,6 +142,10 @@ class NostrClientConnection:
resp_nip20: List[Any] = ["OK", e.id] resp_nip20: List[Any] = ["OK", e.id]
try: try:
e.check_signature() e.check_signature()
if not self.client_config.is_author_allowed(e.pubkey):
raise ValueError(f"Public key '{e.pubkey}' is not allowed in relay '{self.relay_id}'!")
if e.is_replaceable_event(): if e.is_replaceable_event():
await delete_events( await delete_events(
self.relay_id, NostrFilter(kinds=[e.kind], authors=[e.pubkey]) self.relay_id, NostrFilter(kinds=[e.kind], authors=[e.pubkey])
@ -144,7 +156,9 @@ class NostrClientConnection:
if e.is_delete_event(): if e.is_delete_event():
await self._handle_delete_event(e) await self._handle_delete_event(e)
resp_nip20 += [True, ""] resp_nip20 += [True, ""]
except ValueError: except ValueError as ex:
#todo: handle the other Value Errors
logger.debug(ex)
resp_nip20 += [False, "invalid: wrong event `id` or `sig`"] resp_nip20 += [False, "invalid: wrong event `id` or `sig`"]
except Exception as ex: except Exception as ex:
logger.debug(ex) logger.debug(ex)
@ -154,6 +168,12 @@ class NostrClientConnection:
await self.websocket.send_text(json.dumps(resp_nip20)) await self.websocket.send_text(json.dumps(resp_nip20))
@property
def client_config(self) -> ClientConfig:
if not self.get_client_config:
raise Exception("Client not ready!")
return self.get_client_config()
async def _handle_delete_event(self, event: NostrEvent): async def _handle_delete_event(self, event: NostrEvent):
# NIP 09 # NIP 09
filter = NostrFilter(authors=[event.pubkey]) filter = NostrFilter(authors=[event.pubkey])

View file

@ -8,12 +8,19 @@ from pydantic import BaseModel, Field
from secp256k1 import PublicKey from secp256k1 import PublicKey
class ClientConfig(BaseModel): class ClientConfig(BaseModel):
max_client_filters = Field(0, alias="maxClientFilters") max_client_filters = Field(0, alias="maxClientFilters")
allowed_public_keys = Field([], alias="allowedPublicKeys") allowed_public_keys = Field([], alias="allowedPublicKeys")
blocked_public_keys = Field([], alias="blockedPublicKeys") blocked_public_keys = Field([], alias="blockedPublicKeys")
def is_author_allowed(self, p: str) -> bool:
if p in self.blocked_public_keys:
return False
if len(self.allowed_public_keys) == 0:
return True
# todo: check payment
return p in self.allowed_public_keys
class Config: class Config:
allow_population_by_field_name = True allow_population_by_field_name = True
class RelayConfig(ClientConfig): class RelayConfig(ClientConfig):