345 lines
13 KiB
Python
345 lines
13 KiB
Python
import json
|
|
import time
|
|
from collections.abc import Awaitable, Callable
|
|
from typing import Any
|
|
|
|
from fastapi import WebSocket
|
|
from lnbits.helpers import urlsafe_short_hash
|
|
from loguru import logger
|
|
|
|
from ..crud import (
|
|
NostrAccount,
|
|
create_event,
|
|
delete_events,
|
|
get_account,
|
|
get_event,
|
|
get_events,
|
|
mark_events_deleted,
|
|
)
|
|
from .event import NostrEvent, NostrEventType
|
|
from .event_validator import EventValidator
|
|
from .filter import NostrFilter
|
|
from .relay import RelaySpec
|
|
|
|
|
|
class NostrClientConnection:
    """Server-side state for one websocket client of a Nostr relay.

    Keeps the client's active subscription filters and its NIP-42
    authentication state, dispatches incoming client messages
    (EVENT, REQ, CLOSE, AUTH) and pushes matching events to the client.
    """

    def __init__(self, relay_id: str, websocket: WebSocket):
        self.websocket = websocket
        self.relay_id = relay_id
        # Active filters; several filters may share one subscription_id
        # (a single REQ can carry multiple filters).
        self.filters: list[NostrFilter] = []
        self.auth_pubkey: str | None = None  # set if authenticated
        self._auth_challenge: str | None = None
        # Unix time (seconds) when `_auth_challenge` was issued; 0 = never.
        self._auth_challenge_created_at = 0

        self.event_validator = EventValidator(self.relay_id)

        # Both callbacks are injected later through `init_callbacks`.
        self.broadcast_event: (
            Callable[[NostrClientConnection, NostrEvent], Awaitable[None]] | None
        ) = None
        self.get_client_config: Callable[[], RelaySpec] | None = None

    async def start(self):
        """Accept the websocket and process client messages in a loop.

        Errors raised while decoding or handling a single message are logged
        and the loop continues; errors raised by `receive_text` (for example a
        disconnect) propagate to the caller and end the loop.
        """
        await self.websocket.accept()
        while True:
            json_data = await self.websocket.receive_text()
            try:
                data = json.loads(json_data)

                resp = await self._handle_message(data)
                for r in resp:
                    await self._send_msg(r)
            except Exception as e:
                logger.warning(e)

    async def stop(self, reason: str | None):
        """Send a NOTICE to the client and close the websocket (best effort).

        Failures are logged at debug level and otherwise ignored, since the
        socket may already be closed.
        """
        message = reason if reason else "Server closed websocket"
        try:
            await self._send_msg(["NOTICE", message])
        except Exception as ex:
            logger.debug(ex)

        try:
            await self.websocket.close(reason=reason)
        except Exception as ex:
            logger.debug(ex)

    def init_callbacks(self, broadcast_event: Callable, get_client_config: Callable):
        """Wire the broadcast and relay-config callbacks into this connection."""
        self.broadcast_event = broadcast_event
        self.get_client_config = get_client_config
        self.event_validator.get_client_config = get_client_config

    async def notify_event(self, event: NostrEvent) -> bool:
        """Push `event` to this client on every subscription it matches.

        The event is sent at most once per subscription, even when several
        filters of that subscription match it.

        Returns:
            True if the event was sent on at least one subscription.
        """
        if self._is_direct_message_for_other(event):
            return False

        notified = set()
        for nostr_filter in self.filters:
            subscription_id = nostr_filter.subscription_id
            if subscription_id in notified or not nostr_filter.matches(event):
                continue
            await self._send_msg(event.serialize_response(subscription_id))
            notified.add(subscription_id)
        return bool(notified)

    def _is_direct_message_for_other(self, event: NostrEvent) -> bool:
        """
        Direct messages are not intended to be broadcast (even if encrypted).
        If the server requires AUTH for kind '4' then direct message will be
        sent only to the intended client.
        """
        if not event.is_direct_message:
            return False
        if not self.config.event_requires_auth(event.kind):
            return False
        if not self.auth_pubkey:
            return True
        if event.has_tag_value("p", self.auth_pubkey):
            return False
        return True

    async def _broadcast_event(self, e: NostrEvent):
        """Forward `e` to the broadcast callback, if one is registered."""
        if self.broadcast_event:
            await self.broadcast_event(self, e)

    async def _handle_message(self, data: list) -> list:
        """Dispatch one decoded client message.

        Returns:
            Messages to send back to the client. EVENT and AUTH send their
            replies directly and return an empty list.
        """
        if len(data) < 2:
            return []

        message_type = data[0]

        if message_type == NostrEventType.EVENT:
            # Client data first, so it cannot override the server-controlled
            # `relay_id` / `publisher` fields.
            event_dict = {
                **data[1],
                "relay_id": self.relay_id,
                "publisher": data[1]["pubkey"],
            }

            event = NostrEvent(**event_dict)
            await self._handle_event(event)
            return []

        if message_type == NostrEventType.REQ:
            if len(data) < 3:
                return []
            subscription_id = data[1]
            # All filters of one REQ form a single subscription (NIP-01).
            filters = [NostrFilter.parse_obj(f) for f in data[2:]]
            return await self._handle_subscription(subscription_id, filters)

        if message_type == NostrEventType.CLOSE:
            self._handle_close(data[1])
            return []

        if message_type == NostrEventType.AUTH:
            await self._handle_auth()

        return []

    async def _handle_event(self, e: NostrEvent):
        """Validate, store and broadcast an event published by the client.

        Always replies with a NIP-20 ``["OK", id, accepted, message]``; may
        additionally send an ``AUTH`` challenge when the relay requires
        authentication for the event's kind.
        """
        logger.info(f"nostr event: [{e.kind}, {e.pubkey}, '{e.content}']")
        resp_nip20: list[Any] = ["OK", e.id]

        if e.is_auth_response_event:
            valid, message = self.event_validator.validate_auth_event(
                e, self._auth_challenge
            )
            if not valid:
                resp_nip20 += [valid, message]
                await self._send_msg(resp_nip20)
                return None
            self.auth_pubkey = e.pubkey

        if not self.auth_pubkey and self.config.event_requires_auth(e.kind):
            await self._send_msg(["AUTH", self._current_auth_challenge()])
            resp_nip20 += [
                False,
                f"Relay requires authentication for events of kind '{e.kind}'",
            ]
            await self._send_msg(resp_nip20)
            return None

        publisher_pubkey = self.auth_pubkey if self.auth_pubkey else e.pubkey
        valid, message = await self.event_validator.validate_write(e, publisher_pubkey)
        if not valid:
            resp_nip20 += [valid, message]
            await self._send_msg(resp_nip20)
            return None

        try:
            await self._replace_previous_versions(e)
            if not e.is_ephemeral_event:
                await create_event(e)
            await self._broadcast_event(e)

            if e.is_delete_event:
                await self._handle_delete_event(e)
            resp_nip20 += [True, ""]
        except Exception as ex:
            logger.debug(ex)
            event = await get_event(self.relay_id, e.id)
            # todo: handle NIP20 in detail
            message = "error: failed to create event"
            resp_nip20 += [event is not None, message]

        await self._send_msg(resp_nip20)

    async def _replace_previous_versions(self, e: NostrEvent):
        """Delete stored events that `e` supersedes (NIP-01 replacement).

        Replaceable events supersede stored events of the same kind and
        author that are not newer than `e`; addressable events additionally
        match on the ``d`` tag value (skipped when `e` has no non-empty one).
        """
        if e.is_replaceable_event:
            await delete_events(
                self.relay_id,
                NostrFilter(kinds=[e.kind], authors=[e.pubkey], until=e.created_at),
            )
        if e.is_addressable_event:
            # Extract 'd' tag value; malformed tags (< 2 items) are ignored.
            d_tag_value = next(
                (t[1] for t in e.tags if len(t) > 1 and t[0] == "d"), None
            )

            if d_tag_value:
                deletion_filter = NostrFilter(
                    kinds=[e.kind],
                    authors=[e.pubkey],
                    until=e.created_at,
                    **{"#d": [d_tag_value]},  # type: ignore
                )

                await delete_events(self.relay_id, deletion_filter)

    @property
    def config(self) -> RelaySpec:
        """Current relay configuration; raises if callbacks are not wired."""
        if not self.get_client_config:
            raise Exception("Client not ready!")
        return self.get_client_config()

    async def _send_msg(self, data: list):
        """Serialize `data` as JSON and send it to the client."""
        await self.websocket.send_text(json.dumps(data))

    async def _handle_delete_event(self, event: NostrEvent):
        """Mark events referenced by a NIP-09 deletion request as deleted.

        ``e`` tags reference event ids; ``a`` tags reference addressable
        events as ``<kind>:<pubkey>:<d-tag>``. Only events authored by the
        deletion request's author are affected, and deletion requests
        themselves are never marked deleted.
        """
        # Malformed tags (fewer than 2 items) are ignored.
        event_ids = [t[1] for t in event.tags if len(t) > 1 and t[0] == "e"]
        event_addresses = [t[1] for t in event.tags if len(t) > 1 and t[0] == "a"]

        ids_to_delete = []

        # Handle regular event deletions (e tags)
        if event_ids:
            nostr_filter = NostrFilter(authors=[event.pubkey], ids=event_ids)
            events_to_delete = await get_events(self.relay_id, nostr_filter, False)
            ids_to_delete.extend(
                e.id for e in events_to_delete if not e.is_delete_event
            )

        # Handle addressable event deletions (a tags)
        for addr in event_addresses:
            nostr_filter = self._address_deletion_filter(event, addr)
            if nostr_filter is None:
                continue
            events_to_delete = await get_events(self.relay_id, nostr_filter, False)
            ids_to_delete.extend(
                e.id for e in events_to_delete if not e.is_delete_event
            )

        # Only mark events as deleted if we found specific IDs
        if ids_to_delete:
            await mark_events_deleted(self.relay_id, NostrFilter(ids=ids_to_delete))

    def _address_deletion_filter(
        self, event: NostrEvent, addr: str
    ) -> NostrFilter | None:
        """Build the filter for an ``a``-tag deletion, or None if not allowed.

        Logs a warning and returns None for a malformed address, a
        non-integer kind, or an address owned by another pubkey.
        """
        # The d-tag value may itself contain ':' so split at most twice.
        parts = addr.split(":", 2)
        if len(parts) != 3:
            logger.warning(
                f"Invalid address format (expected kind:pubkey:d-tag): {addr}"
            )
            return None
        kind_str, addr_pubkey, d_tag = parts
        try:
            kind = int(kind_str)
        except ValueError:
            logger.warning(f"Invalid kind in address: {addr}")
            return None
        # Only delete if the address pubkey matches the deletion event author
        if addr_pubkey != event.pubkey:
            logger.warning(
                f"Deletion request pubkey mismatch: {addr_pubkey} != {event.pubkey}"
            )
            return None
        # NIP-09: delete versions up to the deletion request's created_at.
        # NOTE: Use "#d" alias, not "d" directly (Pydantic Field alias)
        return NostrFilter(
            authors=[addr_pubkey],
            kinds=[kind],
            until=event.created_at,
            **{"#d": [d_tag]},  # type: ignore
        )

    async def _handle_request(
        self, subscription_id: str, nostr_filter: NostrFilter
    ) -> list:
        """Open (or replace) subscription `subscription_id` with one filter."""
        return await self._handle_subscription(subscription_id, [nostr_filter])

    async def _handle_subscription(
        self, subscription_id: str, filters: list[NostrFilter]
    ) -> list:
        """Open (or replace) a subscription made of all `filters` of one REQ.

        Returns:
            The reply messages: an AUTH challenge or NOTICE when the request
            is refused, otherwise the matching stored events (each sent once,
            even if several filters match it) followed by a single EOSE.
        """
        denial = await self._read_access_denial()
        if denial is not None:
            return denial

        # A REQ reusing an id replaces the previous subscription.
        self._remove_filter(subscription_id)
        if not self._can_add_filter(len(filters)):
            max_filters = self.config.max_client_filters
            return [
                [
                    "NOTICE",
                    f"Maximum number of filters ({max_filters}) exceeded.",
                ]
            ]

        serialized_events: list = []
        sent_ids = set()
        for nostr_filter in filters:
            nostr_filter.subscription_id = subscription_id
            nostr_filter.enforce_limit(self.config.limit_per_filter)
            self.filters.append(nostr_filter)
            events = await get_events(self.relay_id, nostr_filter)
            for event in events:
                if event.id in sent_ids or self._is_direct_message_for_other(event):
                    continue
                sent_ids.add(event.id)
                serialized_events.append(event.serialize_response(subscription_id))

        resp_nip15 = ["EOSE", subscription_id]
        serialized_events.append(resp_nip15)
        return serialized_events

    async def _read_access_denial(self) -> list | None:
        """Return the reply that refuses a REQ, or None if reading is allowed.

        Only enforced when the relay config sets `require_auth_filter`.
        """
        if not self.config.require_auth_filter:
            return None
        if not self.auth_pubkey:
            return [["AUTH", self._current_auth_challenge()]]
        account = await get_account(self.relay_id, self.auth_pubkey)
        if not account:
            account = NostrAccount.null_account()

        if account.blocked:
            return [
                [
                    "NOTICE",
                    (
                        f"Public key '{self.auth_pubkey}' is not allowed "
                        f"in relay '{self.relay_id}'!"
                    ),
                ]
            ]

        if not account.can_join and not self.config.is_free_to_join:
            return [["NOTICE", f"This is a paid relay: '{self.relay_id}'"]]
        return None

    def _remove_filter(self, subscription_id: str):
        """Drop every filter belonging to `subscription_id`."""
        self.filters = [f for f in self.filters if f.subscription_id != subscription_id]

    def _handle_close(self, subscription_id: str):
        """Handle a client CLOSE message by dropping the subscription."""
        self._remove_filter(subscription_id)

    async def _handle_auth(self):
        """Reply to a client AUTH message with the current challenge."""
        await self._send_msg(["AUTH", self._current_auth_challenge()])

    def _can_add_filter(self, count: int = 1) -> bool:
        """Return True if `count` more filters fit within max_client_filters.

        A configured maximum of 0 means "no limit".
        """
        max_filters = self.config.max_client_filters
        return max_filters == 0 or len(self.filters) + count <= max_filters

    def _auth_challenge_expired(self):
        """Return True if no challenge was issued or it is 5+ minutes old."""
        if self._auth_challenge_created_at == 0:
            return True
        current_time_seconds = round(time.time())
        challenge_max_age_seconds = 300  # 5 min
        return (
            current_time_seconds - self._auth_challenge_created_at
        ) >= challenge_max_age_seconds

    def _current_auth_challenge(self):
        """Return the NIP-42 challenge, issuing a new one when expired."""
        if self._auth_challenge_expired():
            self._auth_challenge = self.relay_id + ":" + urlsafe_short_hash()
            self._auth_challenge_created_at = round(time.time())
        return self._auth_challenge
|