refactor: extract extract_domain function

This commit is contained in:
Vlad Stan 2023-02-15 10:33:56 +02:00
parent 3648dc212c
commit 1c81ca300f
3 changed files with 42 additions and 8 deletions

View file

@ -18,6 +18,7 @@ from .crud import (
mark_events_deleted,
prune_old_events,
)
from .helpers import extract_domain
from .models import NostrAccount, NostrEvent, NostrEventType, NostrFilter, RelaySpec
@ -166,6 +167,17 @@ class NostrClientConnection:
async def _handle_event(self, e: NostrEvent):
logger.info(f"nostr event: [{e.kind}, {e.pubkey}, '{e.content}']")
resp_nip20: List[Any] = ["OK", e.id]
if e.is_auth_response_event:
valid, message = self._validate_auth_event(e)
if not valid:
resp_nip20 += [valid, message]
await self._send_msg(resp_nip20)
return None
self.authenticated = True
return None
if not self.authenticated and self.client_config.event_requires_auth(e.kind):
await self._send_msg(["AUTH", self._current_auth_challenge()])
resp_nip20 += [False, "Relay requires authentication"]
@ -180,14 +192,14 @@ class NostrClientConnection:
return None
try:
if e.is_replaceable_event():
if e.is_replaceable_event:
await delete_events(
self.relay_id, NostrFilter(kinds=[e.kind], authors=[e.pubkey])
)
await create_event(self.relay_id, e)
await self._broadcast_event(e)
if e.is_delete_event():
if e.is_delete_event:
await self._handle_delete_event(e)
resp_nip20 += [True, ""]
except Exception as ex:
@ -213,7 +225,7 @@ class NostrClientConnection:
filter = NostrFilter(authors=[event.pubkey])
filter.ids = [t[1] for t in event.tags if t[0] == "e"]
events_to_delete = await get_events(self.relay_id, filter, False)
ids = [e.id for e in events_to_delete if not e.is_delete_event()]
ids = [e.id for e in events_to_delete if not e.is_delete_event]
await mark_events_deleted(self.relay_id, NostrFilter(ids=ids))
async def _handle_request(self, subscription_id: str, filter: NostrFilter) -> List:
@ -255,14 +267,32 @@ class NostrClientConnection:
and len(self.filters) >= self.client_config.max_client_filters
)
async def _validate_write(self, e: NostrEvent) -> Tuple[bool, str]:
def _validate_auth_event(self, e: NostrEvent) -> Tuple[bool, str]:
valid, message = self._validate_event(e)
if not valid:
return [valid, message]
relay_tag = e.tag_values("relay")
challenge_tag = e.tag_values("challenge")
if len(relay_tag) == 0 or len(challenge_tag) == 0:
return False, "NIP42 tags are missing"
if self.client_config.domain != extract_domain(relay_tag[0]):
return False, "Wrong relay domain"
if self._auth_challenge != challenge_tag[0]:
return False, "Wrong chanlange value"
return True, ""
async def _validate_write(self, e: NostrEvent) -> Tuple[bool, str]:
valid, message = self._validate_event(e)
if not valid:
return (valid, message)
valid, message = await self._validate_storage(e.pubkey, e.size_bytes)
if not valid:
return [valid, message]
return (valid, message)
return True, ""