refactor: extract client_connection

This commit is contained in:
Vlad Stan 2023-02-17 14:05:50 +02:00
parent aa68d2a79a
commit c42d81f696
4 changed files with 249 additions and 247 deletions

0
relay/__init__.py Normal file
View file

237
relay/client_connection.py Normal file
View file

@ -0,0 +1,237 @@
import json
import time
from typing import Any, Awaitable, Callable, List, Optional
from fastapi import WebSocket
from loguru import logger
from lnbits.helpers import urlsafe_short_hash
from ..crud import (
create_event,
delete_events,
get_event,
get_events,
mark_events_deleted,
)
from ..models import NostrEvent, NostrEventType, NostrFilter, RelaySpec
from .event_validator import EventValidator
class NostrClientConnection:
def __init__(self, relay_id: str, websocket: WebSocket):
self.websocket = websocket
self.relay_id = relay_id
self.filters: List[NostrFilter] = []
self.pubkey: Optional[str] = None # set if authenticated
self._auth_challenge: Optional[str] = None
self._auth_challenge_created_at = 0
self.event_validator = EventValidator(self.relay_id)
self.broadcast_event: Optional[
Callable[[NostrClientConnection, NostrEvent], Awaitable[None]]
] = None
self.get_client_config: Optional[Callable[[], RelaySpec]] = None
async def start(self):
await self.websocket.accept()
while True:
json_data = await self.websocket.receive_text()
try:
data = json.loads(json_data)
resp = await self._handle_message(data)
for r in resp:
await self._send_msg(r)
except Exception as e:
logger.warning(e)
async def stop(self, reason: Optional[str]):
message = reason if reason else "Server closed webocket"
try:
await self._send_msg(["NOTICE", message])
except:
pass
try:
await self.websocket.close(reason=reason)
except:
pass
def init_callbacks(self, broadcast_event: Callable, get_client_config: Callable):
setattr(self, "broadcast_event", broadcast_event)
setattr(self, "get_client_config", get_client_config)
setattr(self.event_validator, "get_client_config", get_client_config)
async def notify_event(self, event: NostrEvent) -> bool:
if self._is_direct_message_for_other(event):
return False
for filter in self.filters:
if filter.matches(event):
resp = event.serialize_response(filter.subscription_id)
await self._send_msg(resp)
return True
return False
def _is_direct_message_for_other(self, event: NostrEvent) -> bool:
"""
Direct messages are not inteded to be boradcast (even if encrypted).
If the server requires AUTH for kind '4' then direct message will be sent only to the intended client.
"""
if not event.is_direct_message:
return False
if not self.config.event_requires_auth(event.kind):
return False
if not self.pubkey:
return True
if event.has_tag_value("p", self.pubkey):
return False
return True
async def _broadcast_event(self, e: NostrEvent):
if self.broadcast_event:
await self.broadcast_event(self, e)
async def _handle_message(self, data: List) -> List:
if len(data) < 2:
return []
message_type = data[0]
if message_type == NostrEventType.EVENT:
await self._handle_event(NostrEvent.parse_obj(data[1]))
return []
if message_type == NostrEventType.REQ:
if len(data) != 3:
return []
return await self._handle_request(data[1], NostrFilter.parse_obj(data[2]))
if message_type == NostrEventType.CLOSE:
self._handle_close(data[1])
if message_type == NostrEventType.AUTH:
await self._handle_auth()
return []
async def _handle_event(self, e: NostrEvent):
logger.info(f"nostr event: [{e.kind}, {e.pubkey}, '{e.content}']")
resp_nip20: List[Any] = ["OK", e.id]
if e.is_auth_response_event:
valid, message = self.event_validator.validate_auth_event(e, self._auth_challenge)
if not valid:
resp_nip20 += [valid, message]
await self._send_msg(resp_nip20)
return None
self.pubkey = e.pubkey
return None
if not self.pubkey and self.config.event_requires_auth(e.kind):
await self._send_msg(["AUTH", self._current_auth_challenge()])
resp_nip20 += [
False,
f"restricted: Relay requires authentication for events of kind '{e.kind}'",
]
await self._send_msg(resp_nip20)
return None
publisher_pubkey = self.pubkey if self.pubkey else e.pubkey
valid, message = await self.event_validator.validate_write(e, publisher_pubkey)
if not valid:
resp_nip20 += [valid, message]
await self._send_msg(resp_nip20)
return None
try:
if e.is_replaceable_event:
await delete_events(
self.relay_id, NostrFilter(kinds=[e.kind], authors=[e.pubkey], until=e.created_at)
)
if not e.is_ephemeral_event:
await create_event(self.relay_id, e, self.pubkey)
await self._broadcast_event(e)
if e.is_delete_event:
await self._handle_delete_event(e)
resp_nip20 += [True, ""]
except Exception as ex:
logger.debug(ex)
event = await get_event(self.relay_id, e.id)
# todo: handle NIP20 in detail
message = "error: failed to create event"
resp_nip20 += [event != None, message]
await self._send_msg(resp_nip20)
@property
def config(self) -> RelaySpec:
if not self.get_client_config:
raise Exception("Client not ready!")
return self.get_client_config()
async def _send_msg(self, data: List):
await self.websocket.send_text(json.dumps(data))
async def _handle_delete_event(self, event: NostrEvent):
# NIP 09
filter = NostrFilter(authors=[event.pubkey])
filter.ids = [t[1] for t in event.tags if t[0] == "e"]
events_to_delete = await get_events(self.relay_id, filter, False)
ids = [e.id for e in events_to_delete if not e.is_delete_event]
await mark_events_deleted(self.relay_id, NostrFilter(ids=ids))
async def _handle_request(self, subscription_id: str, filter: NostrFilter) -> List:
if not self.pubkey and self.config.require_auth_filter:
return [["AUTH", self._current_auth_challenge()]]
filter.subscription_id = subscription_id
self._remove_filter(subscription_id)
if self._can_add_filter():
return [
[
"NOTICE",
f"Maximum number of filters ({self.config.max_client_filters}) exceeded.",
]
]
filter.enforce_limit(self.config.limit_per_filter)
self.filters.append(filter)
events = await get_events(self.relay_id, filter)
events = [e for e in events if not self._is_direct_message_for_other(e)]
serialized_events = [
event.serialize_response(subscription_id) for event in events
]
resp_nip15 = ["EOSE", subscription_id]
serialized_events.append(resp_nip15)
return serialized_events
def _remove_filter(self, subscription_id: str):
self.filters = [f for f in self.filters if f.subscription_id != subscription_id]
def _handle_close(self, subscription_id: str):
self._remove_filter(subscription_id)
async def _handle_auth(self):
await self._send_msg(["AUTH", self._current_auth_challenge()])
def _can_add_filter(self) -> bool:
return (
self.config.max_client_filters != 0
and len(self.filters) >= self.config.max_client_filters
)
def _auth_challenge_expired(self):
if self._auth_challenge_created_at == 0:
return True
current_time_seconds = round(time.time())
chanllenge_max_age_seconds = 300 # 5 min
return (
current_time_seconds - self._auth_challenge_created_at
) >= chanllenge_max_age_seconds
def _current_auth_challenge(self):
if self._auth_challenge_expired():
self._auth_challenge = self.relay_id + ":" + urlsafe_short_hash()
self._auth_challenge_created_at = round(time.time())
return self._auth_challenge

View file

@ -1,22 +1,9 @@
import json from typing import List
import time
from typing import Any, Awaitable, Callable, List, Optional
from fastapi import WebSocket from .client_connection import NostrClientConnection
from loguru import logger
from lnbits.helpers import urlsafe_short_hash from ..crud import get_config_for_all_active_relays
from ..models import NostrEvent, RelaySpec
from ..crud import (
create_event,
delete_events,
get_config_for_all_active_relays,
get_event,
get_events,
mark_events_deleted,
)
from ..models import NostrEvent, NostrEventType, NostrFilter, RelaySpec
from .event_validator import EventValidator
class NostrClientManager: class NostrClientManager:
@ -25,7 +12,7 @@ class NostrClientManager:
self._active_relays: dict = {} self._active_relays: dict = {}
self._is_ready = False self._is_ready = False
async def add_client(self, c: "NostrClientConnection") -> bool: async def add_client(self, c: NostrClientConnection) -> bool:
if not self._is_ready: if not self._is_ready:
await self.init_relays() await self.init_relays()
@ -37,10 +24,10 @@ class NostrClientManager:
return True return True
def remove_client(self, c: "NostrClientConnection"): def remove_client(self, c: NostrClientConnection):
self.clients(c.relay_id).remove(c) self.clients(c.relay_id).remove(c)
async def broadcast_event(self, source: "NostrClientConnection", event: NostrEvent): async def broadcast_event(self, source: NostrClientConnection, event: NostrEvent):
for client in self.clients(source.relay_id): for client in self.clients(source.relay_id):
await client.notify_event(event) await client.notify_event(event)
@ -60,7 +47,7 @@ class NostrClientManager:
def get_relay_config(self, relay_id: str) -> RelaySpec: def get_relay_config(self, relay_id: str) -> RelaySpec:
return self._active_relays[relay_id] return self._active_relays[relay_id]
def clients(self, relay_id: str) -> List["NostrClientConnection"]: def clients(self, relay_id: str) -> List[NostrClientConnection]:
if relay_id not in self._clients: if relay_id not in self._clients:
self._clients[relay_id] = [] self._clients[relay_id] = []
return self._clients[relay_id] return self._clients[relay_id]
@ -70,236 +57,16 @@ class NostrClientManager:
if client.relay_id == relay_id: if client.relay_id == relay_id:
await client.stop(reason=f"Relay '{relay_id}' has been deactivated.") await client.stop(reason=f"Relay '{relay_id}' has been deactivated.")
async def _allow_client(self, c: "NostrClientConnection") -> bool: async def _allow_client(self, c: NostrClientConnection) -> bool:
if c.relay_id not in self._active_relays: if c.relay_id not in self._active_relays:
await c.stop(reason=f"Relay '{c.relay_id}' is not active") await c.stop(reason=f"Relay '{c.relay_id}' is not active")
return False return False
return True return True
def _set_client_callbacks(self, client: "NostrClientConnection"): def _set_client_callbacks(self, client: NostrClientConnection):
def get_client_config() -> RelaySpec: def get_client_config() -> RelaySpec:
return self.get_relay_config(client.relay_id) return self.get_relay_config(client.relay_id)
setattr(client, "get_client_config", get_client_config) setattr(client, "get_client_config", get_client_config)
client.init_callbacks(self.broadcast_event, get_client_config) client.init_callbacks(self.broadcast_event, get_client_config)
class NostrClientConnection:
def __init__(self, relay_id: str, websocket: WebSocket):
self.websocket = websocket
self.relay_id = relay_id
self.filters: List[NostrFilter] = []
self.pubkey: Optional[str] = None # set if authenticated
self._auth_challenge: Optional[str] = None
self._auth_challenge_created_at = 0
self.event_validator = EventValidator(self.relay_id)
self.broadcast_event: Optional[
Callable[[NostrClientConnection, NostrEvent], Awaitable[None]]
] = None
self.get_client_config: Optional[Callable[[], RelaySpec]] = None
async def start(self):
await self.websocket.accept()
while True:
json_data = await self.websocket.receive_text()
try:
data = json.loads(json_data)
resp = await self._handle_message(data)
for r in resp:
await self._send_msg(r)
except Exception as e:
logger.warning(e)
async def stop(self, reason: Optional[str]):
message = reason if reason else "Server closed webocket"
try:
await self._send_msg(["NOTICE", message])
except:
pass
try:
await self.websocket.close(reason=reason)
except:
pass
def init_callbacks(self, broadcast_event: Callable, get_client_config: Callable):
setattr(self, "broadcast_event", broadcast_event)
setattr(self, "get_client_config", get_client_config)
setattr(self.event_validator, "get_client_config", get_client_config)
async def notify_event(self, event: NostrEvent) -> bool:
if self._is_direct_message_for_other(event):
return False
for filter in self.filters:
if filter.matches(event):
resp = event.serialize_response(filter.subscription_id)
await self._send_msg(resp)
return True
return False
def _is_direct_message_for_other(self, event: NostrEvent) -> bool:
"""
Direct messages are not inteded to be boradcast (even if encrypted).
If the server requires AUTH for kind '4' then direct message will be sent only to the intended client.
"""
if not event.is_direct_message:
return False
if not self.config.event_requires_auth(event.kind):
return False
if not self.pubkey:
return True
if event.has_tag_value("p", self.pubkey):
return False
return True
async def _broadcast_event(self, e: NostrEvent):
if self.broadcast_event:
await self.broadcast_event(self, e)
async def _handle_message(self, data: List) -> List:
if len(data) < 2:
return []
message_type = data[0]
if message_type == NostrEventType.EVENT:
await self._handle_event(NostrEvent.parse_obj(data[1]))
return []
if message_type == NostrEventType.REQ:
if len(data) != 3:
return []
return await self._handle_request(data[1], NostrFilter.parse_obj(data[2]))
if message_type == NostrEventType.CLOSE:
self._handle_close(data[1])
if message_type == NostrEventType.AUTH:
await self._handle_auth()
return []
async def _handle_event(self, e: NostrEvent):
logger.info(f"nostr event: [{e.kind}, {e.pubkey}, '{e.content}']")
resp_nip20: List[Any] = ["OK", e.id]
if e.is_auth_response_event:
valid, message = self.event_validator.validate_auth_event(e, self._auth_challenge)
if not valid:
resp_nip20 += [valid, message]
await self._send_msg(resp_nip20)
return None
self.pubkey = e.pubkey
return None
if not self.pubkey and self.config.event_requires_auth(e.kind):
await self._send_msg(["AUTH", self._current_auth_challenge()])
resp_nip20 += [
False,
f"restricted: Relay requires authentication for events of kind '{e.kind}'",
]
await self._send_msg(resp_nip20)
return None
publisher_pubkey = self.pubkey if self.pubkey else e.pubkey
valid, message = await self.event_validator.validate_write(e, publisher_pubkey)
if not valid:
resp_nip20 += [valid, message]
await self._send_msg(resp_nip20)
return None
try:
if e.is_replaceable_event:
await delete_events(
self.relay_id, NostrFilter(kinds=[e.kind], authors=[e.pubkey], until=e.created_at)
)
if not e.is_ephemeral_event:
await create_event(self.relay_id, e, self.pubkey)
await self._broadcast_event(e)
if e.is_delete_event:
await self._handle_delete_event(e)
resp_nip20 += [True, ""]
except Exception as ex:
logger.debug(ex)
event = await get_event(self.relay_id, e.id)
# todo: handle NIP20 in detail
message = "error: failed to create event"
resp_nip20 += [event != None, message]
await self._send_msg(resp_nip20)
@property
def config(self) -> RelaySpec:
if not self.get_client_config:
raise Exception("Client not ready!")
return self.get_client_config()
async def _send_msg(self, data: List):
await self.websocket.send_text(json.dumps(data))
async def _handle_delete_event(self, event: NostrEvent):
# NIP 09
filter = NostrFilter(authors=[event.pubkey])
filter.ids = [t[1] for t in event.tags if t[0] == "e"]
events_to_delete = await get_events(self.relay_id, filter, False)
ids = [e.id for e in events_to_delete if not e.is_delete_event]
await mark_events_deleted(self.relay_id, NostrFilter(ids=ids))
async def _handle_request(self, subscription_id: str, filter: NostrFilter) -> List:
if not self.pubkey and self.config.require_auth_filter:
return [["AUTH", self._current_auth_challenge()]]
filter.subscription_id = subscription_id
self._remove_filter(subscription_id)
if self._can_add_filter():
return [
[
"NOTICE",
f"Maximum number of filters ({self.config.max_client_filters}) exceeded.",
]
]
filter.enforce_limit(self.config.limit_per_filter)
self.filters.append(filter)
events = await get_events(self.relay_id, filter)
events = [e for e in events if not self._is_direct_message_for_other(e)]
serialized_events = [
event.serialize_response(subscription_id) for event in events
]
resp_nip15 = ["EOSE", subscription_id]
serialized_events.append(resp_nip15)
return serialized_events
def _remove_filter(self, subscription_id: str):
self.filters = [f for f in self.filters if f.subscription_id != subscription_id]
def _handle_close(self, subscription_id: str):
self._remove_filter(subscription_id)
async def _handle_auth(self):
await self._send_msg(["AUTH", self._current_auth_challenge()])
def _can_add_filter(self) -> bool:
return (
self.config.max_client_filters != 0
and len(self.filters) >= self.config.max_client_filters
)
def _auth_challenge_expired(self):
if self._auth_challenge_created_at == 0:
return True
current_time_seconds = round(time.time())
chanllenge_max_age_seconds = 300 # 5 min
return (
current_time_seconds - self._auth_challenge_created_at
) >= chanllenge_max_age_seconds
def _current_auth_challenge(self):
if self._auth_challenge_expired():
self._auth_challenge = self.relay_id + ":" + urlsafe_short_hash()
self._auth_challenge_created_at = round(time.time())
return self._auth_challenge

View file

@ -7,10 +7,8 @@ from fastapi import WebSocket
from loguru import logger from loguru import logger
from lnbits.extensions.nostrrelay.models import RelaySpec # type: ignore from lnbits.extensions.nostrrelay.models import RelaySpec # type: ignore
from lnbits.extensions.nostrrelay.relay.client_manager import ( # type: ignore from lnbits.extensions.nostrrelay.relay.client_manager import NostrClientManager # type: ignore
NostrClientConnection, from lnbits.extensions.nostrrelay.relay.client_connection import NostrClientConnection # type: ignore
NostrClientManager,
)
from .helpers import get_fixtures from .helpers import get_fixtures