diff --git a/client_manager.py b/client_manager.py index 07c3cd5..8bd0c11 100644 --- a/client_manager.py +++ b/client_manager.py @@ -1,5 +1,5 @@ import json -from typing import Any, Callable, List, Optional +from typing import Any, Awaitable, Callable, List, Optional from fastapi import WebSocket from loguru import logger @@ -12,7 +12,7 @@ from .crud import ( get_events, mark_events_deleted, ) -from .models import NostrEvent, NostrEventType, NostrFilter, RelayConfig +from .models import ClientConfig, NostrEvent, NostrEventType, NostrFilter, RelayConfig class NostrClientManager: @@ -28,7 +28,12 @@ class NostrClientManager: allow_connect = await self._allow_client(client) if not allow_connect: return False + setattr(client, "broadcast_event", self.broadcast_event) + def get_client_config() -> ClientConfig: + return self.get_relay_config(client.relay_id) + setattr(client, "get_client_config", get_client_config) + self.clients(client.relay_id).append(client) return True @@ -52,18 +57,20 @@ class NostrClientManager: async def disable_relay(self, relay_id: str): await self._stop_clients_for_relay(relay_id) del self._active_relays[relay_id] + + def get_relay_config(self, relay_id: str) -> RelayConfig: + return self._active_relays[relay_id] + def clients(self, relay_id: str) -> List["NostrClientConnection"]: + if relay_id not in self._clients: + self._clients[relay_id] = [] + return self._clients[relay_id] async def _stop_clients_for_relay(self, relay_id: str): for client in self.clients(relay_id): if client.relay_id == relay_id: await client.stop(reason=f"Relay '{relay_id}' has been deactivated.") - def clients(self, relay_id: str) -> List["NostrClientConnection"]: - if relay_id not in self._clients: - self._clients[relay_id] = [] - return self._clients[relay_id] - async def _allow_client(self, c: "NostrClientConnection") -> bool: if c.relay_id not in self._active_relays: await c.stop(reason=f"Relay '{c.relay_id}' is not active") @@ -77,7 +84,8 @@ class NostrClientConnection: self.websocket = websocket self.relay_id = relay_id self.filters: List[NostrFilter] = [] - self.broadcast_event: Optional[Callable] = None + self.broadcast_event: Optional[Callable[[NostrClientConnection, NostrEvent], Awaitable[None]]] = None + self.get_client_config: Optional[Callable[[], ClientConfig]] = None async def start(self): await self.websocket.accept() @@ -134,6 +142,10 @@ class NostrClientConnection: resp_nip20: List[Any] = ["OK", e.id] try: e.check_signature() + + if not self.client_config.is_author_allowed(e.pubkey): + raise ValueError(f"Public key '{e.pubkey}' is not allowed in relay '{self.relay_id}'!") + if e.is_replaceable_event(): await delete_events( self.relay_id, NostrFilter(kinds=[e.kind], authors=[e.pubkey]) @@ -144,7 +156,9 @@ class NostrClientConnection: if e.is_delete_event(): await self._handle_delete_event(e) resp_nip20 += [True, ""] - except ValueError: + except ValueError as ex: + #todo: handle the other Value Errors + logger.debug(ex) resp_nip20 += [False, "invalid: wrong event `id` or `sig`"] except Exception as ex: logger.debug(ex) @@ -154,6 +168,12 @@ class NostrClientConnection: await self.websocket.send_text(json.dumps(resp_nip20)) + @property + def client_config(self) -> ClientConfig: + if not self.get_client_config: + raise Exception("Client not ready!") + return self.get_client_config() + async def _handle_delete_event(self, event: NostrEvent): # NIP 09 filter = NostrFilter(authors=[event.pubkey]) diff --git a/models.py b/models.py index 655538f..57e639a 100644 --- a/models.py +++ b/models.py @@ -8,12 +8,19 @@ from pydantic import BaseModel, Field from secp256k1 import PublicKey - class ClientConfig(BaseModel): max_client_filters = Field(0, alias="maxClientFilters") allowed_public_keys = Field([], alias="allowedPublicKeys") blocked_public_keys = Field([], alias="blockedPublicKeys") + def is_author_allowed(self, p: str) -> bool: + if p in self.blocked_public_keys: + return False + if len(self.allowed_public_keys) == 0: + return True + # todo: check payment + return p in self.allowed_public_keys + class Config: allow_population_by_field_name = True class RelayConfig(ClientConfig):