diff --git a/client_manager.py b/client_manager.py index d0a2c0f..cb99b44 100644 --- a/client_manager.py +++ b/client_manager.py @@ -18,6 +18,7 @@ from .crud import ( mark_events_deleted, prune_old_events, ) +from .helpers import extract_domain from .models import NostrAccount, NostrEvent, NostrEventType, NostrFilter, RelaySpec @@ -166,6 +167,17 @@ class NostrClientConnection: async def _handle_event(self, e: NostrEvent): logger.info(f"nostr event: [{e.kind}, {e.pubkey}, '{e.content}']") resp_nip20: List[Any] = ["OK", e.id] + + if e.is_auth_response_event: + valid, message = self._validate_auth_event(e) + if not valid: + resp_nip20 += [valid, message] + await self._send_msg(resp_nip20) + return None + self.authenticated = True + return None + + if not self.authenticated and self.client_config.event_requires_auth(e.kind): await self._send_msg(["AUTH", self._current_auth_challenge()]) resp_nip20 += [False, "Relay requires authentication"] @@ -180,14 +192,14 @@ class NostrClientConnection: return None try: - if e.is_replaceable_event(): + if e.is_replaceable_event: await delete_events( self.relay_id, NostrFilter(kinds=[e.kind], authors=[e.pubkey]) ) await create_event(self.relay_id, e) await self._broadcast_event(e) - if e.is_delete_event(): + if e.is_delete_event: await self._handle_delete_event(e) resp_nip20 += [True, ""] except Exception as ex: @@ -213,7 +225,7 @@ class NostrClientConnection: filter = NostrFilter(authors=[event.pubkey]) filter.ids = [t[1] for t in event.tags if t[0] == "e"] events_to_delete = await get_events(self.relay_id, filter, False) - ids = [e.id for e in events_to_delete if not e.is_delete_event()] + ids = [e.id for e in events_to_delete if not e.is_delete_event] await mark_events_deleted(self.relay_id, NostrFilter(ids=ids)) async def _handle_request(self, subscription_id: str, filter: NostrFilter) -> List: @@ -255,14 +267,32 @@ class NostrClientConnection: and len(self.filters) >= self.client_config.max_client_filters ) - async def _validate_write(self, e: NostrEvent) -> Tuple[bool, str]: + def _validate_auth_event(self, e: NostrEvent) -> Tuple[bool, str]: valid, message = self._validate_event(e) if not valid: return [valid, message] + relay_tag = e.tag_values("relay") + challenge_tag = e.tag_values("challenge") + if len(relay_tag) == 0 or len(challenge_tag) == 0: + return False, "NIP42 tags are missing" + + if self.client_config.domain != extract_domain(relay_tag[0]): + return False, "Wrong relay domain" + + if self._auth_challenge != challenge_tag[0]: + return False, "Wrong chanlange value" + + return True, "" + + async def _validate_write(self, e: NostrEvent) -> Tuple[bool, str]: + valid, message = self._validate_event(e) + if not valid: + return (valid, message) + valid, message = await self._validate_storage(e.pubkey, e.size_bytes) if not valid: - return [valid, message] + return (valid, message) return True, "" diff --git a/helpers.py b/helpers.py index bcf5c02..8e0b15d 100644 --- a/helpers.py +++ b/helpers.py @@ -1,3 +1,5 @@ +from urllib.parse import urlparse + from bech32 import bech32_decode, convertbits @@ -17,3 +19,6 @@ def normalize_public_key(pubkey: str) -> str: raise ValueError("Public Key is not valid hex") int(pubkey, 16) return pubkey + +def extract_domain(url: str) -> str: + return urlparse(url).netloc \ No newline at end of file diff --git a/views_api.py b/views_api.py index df53d96..f44ca95 100644 --- a/views_api.py +++ b/views_api.py @@ -1,6 +1,5 @@ from http import HTTPStatus from typing import List, Optional -from urllib.parse import urlparse from fastapi import Depends, Request, WebSocket from fastapi.exceptions import HTTPException @@ -27,7 +26,7 @@ from .crud import ( get_relays, update_relay, ) -from .helpers import normalize_public_key +from .helpers import extract_domain, normalize_public_key from .models import BuyOrder, NostrRelay client_manager = NostrClientManager() @@ -57,7 +56,7 @@ async def api_create_relay( data.id = urlsafe_short_hash()[:8] try: - data.config.domain = urlparse(str(request.url)).netloc + data.config.domain = extract_domain(str(request.url)) relay = await create_relay(wallet.wallet.user, data) return relay