feat: check access control list
This commit is contained in:
parent
e1fe19b115
commit
43dc3e85ce
2 changed files with 37 additions and 10 deletions
|
|
@ -1,5 +1,5 @@
|
||||||
import json
|
import json
|
||||||
from typing import Any, Callable, List, Optional
|
from typing import Any, Awaitable, Callable, List, Optional
|
||||||
|
|
||||||
from fastapi import WebSocket
|
from fastapi import WebSocket
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
@ -12,7 +12,7 @@ from .crud import (
|
||||||
get_events,
|
get_events,
|
||||||
mark_events_deleted,
|
mark_events_deleted,
|
||||||
)
|
)
|
||||||
from .models import NostrEvent, NostrEventType, NostrFilter, RelayConfig
|
from .models import ClientConfig, NostrEvent, NostrEventType, NostrFilter, RelayConfig
|
||||||
|
|
||||||
|
|
||||||
class NostrClientManager:
|
class NostrClientManager:
|
||||||
|
|
@ -28,7 +28,12 @@ class NostrClientManager:
|
||||||
allow_connect = await self._allow_client(client)
|
allow_connect = await self._allow_client(client)
|
||||||
if not allow_connect:
|
if not allow_connect:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
setattr(client, "broadcast_event", self.broadcast_event)
|
setattr(client, "broadcast_event", self.broadcast_event)
|
||||||
|
def get_client_config() -> ClientConfig:
|
||||||
|
return self.get_relay_config(client.relay_id)
|
||||||
|
setattr(client, "get_client_config", get_client_config)
|
||||||
|
|
||||||
self.clients(client.relay_id).append(client)
|
self.clients(client.relay_id).append(client)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
@ -52,18 +57,20 @@ class NostrClientManager:
|
||||||
async def disable_relay(self, relay_id: str):
|
async def disable_relay(self, relay_id: str):
|
||||||
await self._stop_clients_for_relay(relay_id)
|
await self._stop_clients_for_relay(relay_id)
|
||||||
del self._active_relays[relay_id]
|
del self._active_relays[relay_id]
|
||||||
|
|
||||||
|
def get_relay_config(self, relay_id: str) -> RelayConfig:
|
||||||
|
return self._active_relays[relay_id]
|
||||||
|
|
||||||
|
def clients(self, relay_id: str) -> List["NostrClientConnection"]:
|
||||||
|
if relay_id not in self._clients:
|
||||||
|
self._clients[relay_id] = []
|
||||||
|
return self._clients[relay_id]
|
||||||
|
|
||||||
async def _stop_clients_for_relay(self, relay_id: str):
|
async def _stop_clients_for_relay(self, relay_id: str):
|
||||||
for client in self.clients(relay_id):
|
for client in self.clients(relay_id):
|
||||||
if client.relay_id == relay_id:
|
if client.relay_id == relay_id:
|
||||||
await client.stop(reason=f"Relay '{relay_id}' has been deactivated.")
|
await client.stop(reason=f"Relay '{relay_id}' has been deactivated.")
|
||||||
|
|
||||||
def clients(self, relay_id: str) -> List["NostrClientConnection"]:
|
|
||||||
if relay_id not in self._clients:
|
|
||||||
self._clients[relay_id] = []
|
|
||||||
return self._clients[relay_id]
|
|
||||||
|
|
||||||
async def _allow_client(self, c: "NostrClientConnection") -> bool:
|
async def _allow_client(self, c: "NostrClientConnection") -> bool:
|
||||||
if c.relay_id not in self._active_relays:
|
if c.relay_id not in self._active_relays:
|
||||||
await c.stop(reason=f"Relay '{c.relay_id}' is not active")
|
await c.stop(reason=f"Relay '{c.relay_id}' is not active")
|
||||||
|
|
@ -77,7 +84,8 @@ class NostrClientConnection:
|
||||||
self.websocket = websocket
|
self.websocket = websocket
|
||||||
self.relay_id = relay_id
|
self.relay_id = relay_id
|
||||||
self.filters: List[NostrFilter] = []
|
self.filters: List[NostrFilter] = []
|
||||||
self.broadcast_event: Optional[Callable] = None
|
self.broadcast_event: Optional[Callable[[NostrClientConnection, NostrEvent], Awaitable[None]]] = None
|
||||||
|
self.get_client_config: Optional[Callable[[], ClientConfig]] = None
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
await self.websocket.accept()
|
await self.websocket.accept()
|
||||||
|
|
@ -134,6 +142,10 @@ class NostrClientConnection:
|
||||||
resp_nip20: List[Any] = ["OK", e.id]
|
resp_nip20: List[Any] = ["OK", e.id]
|
||||||
try:
|
try:
|
||||||
e.check_signature()
|
e.check_signature()
|
||||||
|
|
||||||
|
if not self.client_config.is_author_allowed(e.pubkey):
|
||||||
|
raise ValueError(f"Public key '{e.pubkey}' is not allowed in relay '{self.relay_id}'!")
|
||||||
|
|
||||||
if e.is_replaceable_event():
|
if e.is_replaceable_event():
|
||||||
await delete_events(
|
await delete_events(
|
||||||
self.relay_id, NostrFilter(kinds=[e.kind], authors=[e.pubkey])
|
self.relay_id, NostrFilter(kinds=[e.kind], authors=[e.pubkey])
|
||||||
|
|
@ -144,7 +156,9 @@ class NostrClientConnection:
|
||||||
if e.is_delete_event():
|
if e.is_delete_event():
|
||||||
await self._handle_delete_event(e)
|
await self._handle_delete_event(e)
|
||||||
resp_nip20 += [True, ""]
|
resp_nip20 += [True, ""]
|
||||||
except ValueError:
|
except ValueError as ex:
|
||||||
|
#todo: handle the other Value Errors
|
||||||
|
logger.debug(ex)
|
||||||
resp_nip20 += [False, "invalid: wrong event `id` or `sig`"]
|
resp_nip20 += [False, "invalid: wrong event `id` or `sig`"]
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
logger.debug(ex)
|
logger.debug(ex)
|
||||||
|
|
@ -154,6 +168,12 @@ class NostrClientConnection:
|
||||||
|
|
||||||
await self.websocket.send_text(json.dumps(resp_nip20))
|
await self.websocket.send_text(json.dumps(resp_nip20))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def client_config(self) -> ClientConfig:
|
||||||
|
if not self.get_client_config:
|
||||||
|
raise Exception("Client not ready!")
|
||||||
|
return self.get_client_config()
|
||||||
|
|
||||||
async def _handle_delete_event(self, event: NostrEvent):
|
async def _handle_delete_event(self, event: NostrEvent):
|
||||||
# NIP 09
|
# NIP 09
|
||||||
filter = NostrFilter(authors=[event.pubkey])
|
filter = NostrFilter(authors=[event.pubkey])
|
||||||
|
|
|
||||||
|
|
@ -8,12 +8,19 @@ from pydantic import BaseModel, Field
|
||||||
from secp256k1 import PublicKey
|
from secp256k1 import PublicKey
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ClientConfig(BaseModel):
|
class ClientConfig(BaseModel):
|
||||||
max_client_filters = Field(0, alias="maxClientFilters")
|
max_client_filters = Field(0, alias="maxClientFilters")
|
||||||
allowed_public_keys = Field([], alias="allowedPublicKeys")
|
allowed_public_keys = Field([], alias="allowedPublicKeys")
|
||||||
blocked_public_keys = Field([], alias="blockedPublicKeys")
|
blocked_public_keys = Field([], alias="blockedPublicKeys")
|
||||||
|
|
||||||
|
def is_author_allowed(self, p: str) -> bool:
|
||||||
|
if p in self.blocked_public_keys:
|
||||||
|
return False
|
||||||
|
if len(self.allowed_public_keys) == 0:
|
||||||
|
return True
|
||||||
|
# todo: check payment
|
||||||
|
return p in self.allowed_public_keys
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
allow_population_by_field_name = True
|
allow_population_by_field_name = True
|
||||||
class RelayConfig(ClientConfig):
|
class RelayConfig(ClientConfig):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue