feat: partial AUTH support

This commit is contained in:
Vlad Stan 2023-02-14 17:26:40 +02:00
parent d0c6f1392b
commit 3648dc212c
7 changed files with 141 additions and 16 deletions

View file

@ -5,6 +5,8 @@ from typing import Any, Awaitable, Callable, List, Optional, Tuple
from fastapi import WebSocket
from loguru import logger
from lnbits.helpers import urlsafe_short_hash
from .crud import (
create_event,
delete_events,
@ -74,7 +76,6 @@ class NostrClientManager:
if c.relay_id not in self._active_relays:
await c.stop(reason=f"Relay '{c.relay_id}' is not active")
return False
# todo: NIP-42: AUTH
return True
def _set_client_callbacks(self, client):
@ -91,13 +92,20 @@ class NostrClientConnection:
self.websocket = websocket
self.relay_id = relay_id
self.filters: List[NostrFilter] = []
self.authenticated = False
self.pubkey: str = None
self._auth_challenge: str = None
self._auth_challenge_created_at = 0
self._last_event_timestamp = 0 # in seconds
self._event_count_per_timestamp = 0
self.broadcast_event: Optional[
Callable[[NostrClientConnection, NostrEvent], Awaitable[None]]
] = None
self.get_client_config: Optional[Callable[[], RelaySpec]] = None
self._last_event_timestamp = 0 # in seconds
self._event_count_per_timestamp = 0
async def start(self):
await self.websocket.accept()
@ -150,13 +158,21 @@ class NostrClientConnection:
return await self._handle_request(data[1], NostrFilter.parse_obj(data[2]))
if message_type == NostrEventType.CLOSE:
self._handle_close(data[1])
if message_type == NostrEventType.AUTH:
self._handle_auth(data[1])
return []
async def _handle_event(self, e: NostrEvent):
resp_nip20: List[Any] = ["OK", e.id]
logger.info(f"nostr event: [{e.kind}, {e.pubkey}, '{e.content}']")
resp_nip20: List[Any] = ["OK", e.id]
if not self.authenticated and self.client_config.event_requires_auth(e.kind):
await self._send_msg(["AUTH", self._current_auth_challenge()])
resp_nip20 += [False, "Relay requires authentication"]
await self._send_msg(resp_nip20)
return None
valid, message = await self._validate_write(e)
if not valid:
resp_nip20 += [valid, message]
@ -201,6 +217,9 @@ class NostrClientConnection:
await mark_events_deleted(self.relay_id, NostrFilter(ids=ids))
async def _handle_request(self, subscription_id: str, filter: NostrFilter) -> List:
if not self.authenticated and self.client_config.require_auth_filter:
return [["AUTH", self._current_auth_challenge()]]
filter.subscription_id = subscription_id
self._remove_filter(subscription_id)
if self._can_add_filter():
@ -227,6 +246,9 @@ class NostrClientConnection:
def _handle_close(self, subscription_id: str):
self._remove_filter(subscription_id)
def _handle_auth(self):
raise ValueError('Not supported')
def _can_add_filter(self) -> bool:
return (
self.client_config.max_client_filters != 0
@ -318,3 +340,16 @@ class NostrClientConnection:
if created_at > (current_time + self.client_config.created_at_in_future):
return False, "created_at is too much into the future"
return True, ""
def _auth_challenge_expired(self):
if self._auth_challenge_created_at == 0:
return True
current_time_seconds = round(time.time())
chanllenge_max_age_seconds = 300 # 5 min
return (current_time_seconds - self._auth_challenge_created_at) >= chanllenge_max_age_seconds
def _current_auth_challenge(self):
if self._auth_challenge_expired():
self._auth_challenge = self.relay_id + ":" + urlsafe_short_hash()
self._auth_challenge_created_at = round(time.time())
return self._auth_challenge