import json
import time
from collections.abc import Awaitable, Callable
from typing import Any

from fastapi import WebSocket
from lnbits.helpers import urlsafe_short_hash
from loguru import logger

from ..crud import (
    NostrAccount,
    create_event,
    delete_events,
    get_account,
    get_event,
    get_events,
    mark_events_deleted,
)
from .event import NostrEvent, NostrEventType
from .event_validator import EventValidator
from .filter import NostrFilter
from .relay import RelaySpec


class NostrClientConnection:
    """One client websocket connected to a single relay.

    Decodes incoming client messages (EVENT, REQ, CLOSE, AUTH), keeps the
    client's subscription filters and pushes matching events back to it.
    Relay-wide behaviour (broadcasting to other clients, relay config) is
    injected via :meth:`init_callbacks`.
    """

    # Age in seconds after which a new AUTH challenge is generated.
    AUTH_CHALLENGE_MAX_AGE_SECONDS = 300  # 5 min

    def __init__(self, relay_id: str, websocket: WebSocket):
        self.websocket = websocket
        self.relay_id = relay_id
        # Active filters; a REQ with several filters stores all of them
        # under the same subscription_id.
        self.filters: list[NostrFilter] = []
        self.auth_pubkey: str | None = None  # set if authenticated
        self._auth_challenge: str | None = None
        self._auth_challenge_created_at = 0

        self.event_validator = EventValidator(self.relay_id)

        self.broadcast_event: (
            Callable[[NostrClientConnection, NostrEvent], Awaitable[None]] | None
        ) = None
        self.get_client_config: Callable[[], RelaySpec] | None = None

    async def start(self):
        """Accept the websocket and handle client messages forever.

        An error while handling one message is logged and the loop goes on;
        an error raised by ``receive_text`` (presumably a disconnect) is not
        caught and ends the loop.
        """
        await self.websocket.accept()
        while True:
            json_data = await self.websocket.receive_text()
            try:
                data = json.loads(json_data)
                resp = await self._handle_message(data)
                for r in resp:
                    await self._send_msg(r)
            except Exception as e:
                logger.warning(e)

    async def stop(self, reason: str | None):
        """Best-effort shutdown: send a NOTICE, then close the websocket.

        Send/close failures are ignored on purpose; the client may already
        be gone.
        """
        message = reason if reason else "Server closed websocket"
        try:
            await self._send_msg(["NOTICE", message])
        except Exception:
            pass
        try:
            await self.websocket.close(reason=reason)
        except Exception:
            pass

    def init_callbacks(self, broadcast_event: Callable, get_client_config: Callable):
        """Wire the relay-level callbacks (broadcast and config lookup)."""
        self.broadcast_event = broadcast_event
        self.get_client_config = get_client_config
        self.event_validator.get_client_config = get_client_config

    async def notify_event(self, event: NostrEvent) -> bool:
        """Push ``event`` to every subscription of this client that matches it.

        Each subscription receives the event at most once, even when several
        of its filters match.

        Returns:
            True if the event was sent for at least one subscription.
        """
        if self._is_direct_message_for_other(event):
            return False

        notified: set[str | None] = set()
        for nostr_filter in self.filters:
            sub_id = nostr_filter.subscription_id
            if sub_id in notified or not nostr_filter.matches(event):
                continue
            await self._send_msg(event.serialize_response(sub_id))
            notified.add(sub_id)
        return bool(notified)

    def _is_direct_message_for_other(self, event: NostrEvent) -> bool:
        """
        Direct messages are not intended to be broadcast (even if encrypted).
        If the server requires AUTH for the event's kind, a direct message is
        only delivered to a client authenticated as one of its 'p'-tagged
        recipients.
        """
        if not event.is_direct_message:
            return False
        if not self.config.event_requires_auth(event.kind):
            return False
        if not self.auth_pubkey:
            return True
        if event.has_tag_value("p", self.auth_pubkey):
            return False
        return True

    async def _broadcast_event(self, e: NostrEvent):
        """Hand ``e`` to the relay's broadcast callback, if one is set."""
        if self.broadcast_event:
            await self.broadcast_event(self, e)

    async def _handle_message(self, data: list) -> list:
        """Dispatch one decoded client message.

        Returns:
            The messages that ``start`` should send back to the client.
        """
        if len(data) < 2:
            return []

        message_type = data[0]

        if message_type == NostrEventType.EVENT:
            await self._handle_event(self._parse_event(data[1]))
            return []
        if message_type == NostrEventType.REQ:
            if len(data) < 3:
                return []
            # Parse every filter first so a malformed one rejects the whole
            # REQ instead of leaving it half-registered.
            filters = [NostrFilter.parse_obj(f) for f in data[2:]]
            return await self._handle_request(data[1], *filters)
        if message_type == NostrEventType.CLOSE:
            self._handle_close(data[1])
            return []
        if message_type == NostrEventType.AUTH:
            await self._handle_auth_message(data[1])
            return []

        return []

    def _parse_event(self, event_data: dict) -> NostrEvent:
        """Build a NostrEvent for this relay from a client-supplied event dict."""
        event_dict = {
            "relay_id": self.relay_id,
            "publisher": event_data["pubkey"],
            **event_data,
        }
        return NostrEvent(**event_dict)

    async def _handle_auth_message(self, payload: Any):
        """Handle a client ``["AUTH", ...]`` message.

        NIP-42 clients answer a challenge with ``["AUTH", <signed event>]``;
        such a payload is validated as an auth event. Any other payload is
        treated as a request for the current challenge.
        """
        if not isinstance(payload, dict):
            await self._handle_auth()
            return

        event = self._parse_event(payload)
        if not event.is_auth_response_event:
            await self._send_msg(
                [
                    "OK",
                    event.id,
                    False,
                    "invalid: AUTH message must carry an authentication event",
                ]
            )
            return
        await self._handle_event(event)

    async def _handle_event(self, e: NostrEvent):
        """Validate, store and broadcast an incoming event, then reply OK.

        Auth response events only authenticate this connection: they are
        acknowledged but never stored or broadcast (NIP-42).
        """
        logger.info(f"nostr event: [{e.kind}, {e.pubkey}, '{e.content}']")
        resp_nip20: list[Any] = ["OK", e.id]

        if e.is_auth_response_event:
            valid, message = self.event_validator.validate_auth_event(
                e, self._auth_challenge
            )
            if not valid:
                resp_nip20 += [valid, message]
                await self._send_msg(resp_nip20)
                return None
            self.auth_pubkey = e.pubkey

        if not self.auth_pubkey and self.config.event_requires_auth(e.kind):
            await self._send_msg(["AUTH", self._current_auth_challenge()])
            resp_nip20 += [
                False,
                f"Relay requires authentication for events of kind '{e.kind}'",
            ]
            await self._send_msg(resp_nip20)
            return None

        publisher_pubkey = self.auth_pubkey if self.auth_pubkey else e.pubkey
        valid, message = await self.event_validator.validate_write(e, publisher_pubkey)
        if not valid:
            resp_nip20 += [valid, message]
            await self._send_msg(resp_nip20)
            return None

        if e.is_auth_response_event:
            # NIP-42: relays must not broadcast auth events to other clients.
            resp_nip20 += [True, ""]
            await self._send_msg(resp_nip20)
            return None

        try:
            await self._delete_replaced_events(e)
            if not e.is_ephemeral_event:
                await create_event(e)
            await self._broadcast_event(e)

            if e.is_delete_event:
                await self._handle_delete_event(e)
            resp_nip20 += [True, ""]
        except Exception as ex:
            logger.debug(ex)
            # The event may still have been stored before the failure.
            event = await get_event(self.relay_id, e.id)
            # todo: handle NIP20 in detail
            message = "error: failed to create event"
            resp_nip20 += [event is not None, message]

        await self._send_msg(resp_nip20)

    async def _delete_replaced_events(self, e: NostrEvent):
        """Delete stored events superseded by ``e`` (replaceable/addressable)."""
        if e.is_replaceable_event:
            await delete_events(
                self.relay_id,
                NostrFilter(kinds=[e.kind], authors=[e.pubkey], until=e.created_at),
            )
        if e.is_addressable_event:
            # Addressable events are replaced per 'd' tag value (NIP-01).
            d_tag_value = next(iter(self._tag_values(e, "d")), None)
            if d_tag_value:
                deletion_filter = NostrFilter(
                    kinds=[e.kind],
                    authors=[e.pubkey],
                    **{"#d": [d_tag_value]},  # type: ignore
                    until=e.created_at,
                )
                await delete_events(self.relay_id, deletion_filter)

    @staticmethod
    def _tag_values(event: NostrEvent, name: str) -> list:
        """Return the first value of every ``name`` tag; malformed tags are skipped."""
        return [t[1] for t in event.tags if len(t) > 1 and t[0] == name]

    @property
    def config(self) -> RelaySpec:
        """Current relay configuration.

        Raises:
            Exception: if ``init_callbacks`` has not been called yet.
        """
        if not self.get_client_config:
            raise Exception("Client not ready!")
        return self.get_client_config()

    async def _send_msg(self, data: list):
        """Serialize ``data`` to JSON and send it as a text frame."""
        await self.websocket.send_text(json.dumps(data))

    async def _handle_delete_event(self, event: NostrEvent):
        """NIP-09: mark the events referenced by a deletion request as deleted.

        'e' tags reference events by id; 'a' tags reference addressable events
        as ``kind:pubkey:d-tag``. Only events authored by the requester are
        affected, and other deletion events are never deleted.
        """
        ids_to_delete: list = []

        event_ids = self._tag_values(event, "e")
        if event_ids:
            nostr_filter = NostrFilter(authors=[event.pubkey], ids=event_ids)
            ids_to_delete += await self._deletable_ids(nostr_filter)

        for addr in self._tag_values(event, "a"):
            address_filter = self._address_filter(addr, event.pubkey)
            if address_filter is not None:
                ids_to_delete += await self._deletable_ids(address_filter)

        # Only mark events as deleted if we found specific IDs
        if ids_to_delete:
            await mark_events_deleted(self.relay_id, NostrFilter(ids=ids_to_delete))

    async def _deletable_ids(self, nostr_filter: NostrFilter) -> list:
        """Return ids of stored events matching ``nostr_filter``, excluding deletion events."""
        events = await get_events(self.relay_id, nostr_filter, False)
        return [e.id for e in events if not e.is_delete_event]

    @staticmethod
    def _address_filter(addr: str, author: str) -> NostrFilter | None:
        """Build a filter for an ``a``-tag address, or None if it is unusable.

        The address must be ``kind:pubkey:d-tag`` with an integer kind and a
        pubkey equal to ``author``. The d-tag itself may contain ':'.
        """
        parts = addr.split(":", 2)
        if len(parts) != 3:
            logger.warning(
                f"Invalid address format (expected kind:pubkey:d-tag): {addr}"
            )
            return None
        kind_str, addr_pubkey, d_tag = parts
        try:
            kind = int(kind_str)
        except ValueError:
            logger.warning(f"Invalid kind in address: {addr}")
            return None
        # Only delete if the address pubkey matches the deletion event author
        if addr_pubkey != author:
            logger.warning(
                f"Deletion request pubkey mismatch: {addr_pubkey} != {author}"
            )
            return None
        # NOTE: Use "#d" alias, not "d" directly (Pydantic Field alias)
        return NostrFilter(authors=[addr_pubkey], kinds=[kind], **{"#d": [d_tag]})

    async def _handle_request(
        self,
        subscription_id: str,
        nostr_filter: NostrFilter,
        *more_filters: NostrFilter,
    ) -> list:
        """Register a subscription (REQ) and return its stored events.

        All given filters are registered under ``subscription_id``, replacing
        whatever that subscription had before. Stored matching events are
        returned once each, followed by a single EOSE.

        Returns:
            Messages for the client: events + EOSE, or an AUTH/NOTICE
            rejection when the client may not subscribe or the filter limit
            would be exceeded.
        """
        rejection = await self._check_request_access()
        if rejection:
            return rejection

        filters = [nostr_filter, *more_filters]
        self._remove_filter(subscription_id)
        if self._filter_limit_reached(len(filters)):
            max_filters = self.config.max_client_filters
            return [
                [
                    "NOTICE",
                    f"Maximum number of filters ({max_filters}) exceeded.",
                ]
            ]

        responses: list = []
        seen_ids: set = set()
        for current_filter in filters:
            current_filter.subscription_id = subscription_id
            current_filter.enforce_limit(self.config.limit_per_filter)
            self.filters.append(current_filter)

            events = await get_events(self.relay_id, current_filter)
            for event in events:
                # An event matching several filters is sent only once.
                if event.id in seen_ids or self._is_direct_message_for_other(event):
                    continue
                seen_ids.add(event.id)
                responses.append(event.serialize_response(subscription_id))

        resp_nip15 = ["EOSE", subscription_id]
        responses.append(resp_nip15)
        return responses

    async def _check_request_access(self) -> list:
        """Return AUTH/NOTICE responses if this client may not subscribe, else []."""
        if not self.config.require_auth_filter:
            return []
        if not self.auth_pubkey:
            return [["AUTH", self._current_auth_challenge()]]
        account = await get_account(self.relay_id, self.auth_pubkey)
        if not account:
            account = NostrAccount.null_account()

        if account.blocked:
            return [
                [
                    "NOTICE",
                    (
                        f"Public key '{self.auth_pubkey}' is not allowed "
                        f"in relay '{self.relay_id}'!"
                    ),
                ]
            ]

        if not account.can_join and not self.config.is_free_to_join:
            return [["NOTICE", f"This is a paid relay: '{self.relay_id}'"]]
        return []

    def _remove_filter(self, subscription_id: str):
        """Drop every filter registered under ``subscription_id``."""
        self.filters = [f for f in self.filters if f.subscription_id != subscription_id]

    def _handle_close(self, subscription_id: str):
        """Handle a CLOSE message by dropping the subscription's filters."""
        self._remove_filter(subscription_id)

    async def _handle_auth(self):
        """Send the current (possibly refreshed) AUTH challenge to the client."""
        await self._send_msg(["AUTH", self._current_auth_challenge()])

    def _can_add_filter(self) -> bool:
        """Return True when one more filter would exceed the limit.

        Kept for backward compatibility: despite its name this returns True
        when a filter can NOT be added. Prefer ``_filter_limit_reached``.
        """
        return self._filter_limit_reached()

    def _filter_limit_reached(self, new_filters: int = 1) -> bool:
        """Return True if adding ``new_filters`` would exceed ``max_client_filters``.

        A ``max_client_filters`` of 0 means no limit.
        """
        max_filters = self.config.max_client_filters
        return max_filters != 0 and len(self.filters) + new_filters > max_filters

    def _auth_challenge_expired(self):
        """Return True if no challenge exists yet or the current one is too old."""
        if self._auth_challenge_created_at == 0:
            return True
        age_seconds = round(time.time()) - self._auth_challenge_created_at
        return age_seconds >= self.AUTH_CHALLENGE_MAX_AGE_SECONDS

    def _current_auth_challenge(self):
        """Return the active AUTH challenge, generating a new one if it expired."""
        if self._auth_challenge_expired():
            self._auth_challenge = self.relay_id + ":" + urlsafe_short_hash()
            self._auth_challenge_created_at = round(time.time())
        return self._auth_challenge