refactor: extract extract_domain function
This commit is contained in:
parent
3648dc212c
commit
1c81ca300f
3 changed files with 42 additions and 8 deletions
|
|
@ -18,6 +18,7 @@ from .crud import (
|
|||
mark_events_deleted,
|
||||
prune_old_events,
|
||||
)
|
||||
from .helpers import extract_domain
|
||||
from .models import NostrAccount, NostrEvent, NostrEventType, NostrFilter, RelaySpec
|
||||
|
||||
|
||||
|
|
@ -166,6 +167,17 @@ class NostrClientConnection:
|
|||
async def _handle_event(self, e: NostrEvent):
|
||||
logger.info(f"nostr event: [{e.kind}, {e.pubkey}, '{e.content}']")
|
||||
resp_nip20: List[Any] = ["OK", e.id]
|
||||
|
||||
if e.is_auth_response_event:
|
||||
valid, message = self._validate_auth_event(e)
|
||||
if not valid:
|
||||
resp_nip20 += [valid, message]
|
||||
await self._send_msg(resp_nip20)
|
||||
return None
|
||||
self.authenticated = True
|
||||
return None
|
||||
|
||||
|
||||
if not self.authenticated and self.client_config.event_requires_auth(e.kind):
|
||||
await self._send_msg(["AUTH", self._current_auth_challenge()])
|
||||
resp_nip20 += [False, "Relay requires authentication"]
|
||||
|
|
@ -180,14 +192,14 @@ class NostrClientConnection:
|
|||
return None
|
||||
|
||||
try:
|
||||
if e.is_replaceable_event():
|
||||
if e.is_replaceable_event:
|
||||
await delete_events(
|
||||
self.relay_id, NostrFilter(kinds=[e.kind], authors=[e.pubkey])
|
||||
)
|
||||
await create_event(self.relay_id, e)
|
||||
await self._broadcast_event(e)
|
||||
|
||||
if e.is_delete_event():
|
||||
if e.is_delete_event:
|
||||
await self._handle_delete_event(e)
|
||||
resp_nip20 += [True, ""]
|
||||
except Exception as ex:
|
||||
|
|
@ -213,7 +225,7 @@ class NostrClientConnection:
|
|||
filter = NostrFilter(authors=[event.pubkey])
|
||||
filter.ids = [t[1] for t in event.tags if t[0] == "e"]
|
||||
events_to_delete = await get_events(self.relay_id, filter, False)
|
||||
ids = [e.id for e in events_to_delete if not e.is_delete_event()]
|
||||
ids = [e.id for e in events_to_delete if not e.is_delete_event]
|
||||
await mark_events_deleted(self.relay_id, NostrFilter(ids=ids))
|
||||
|
||||
async def _handle_request(self, subscription_id: str, filter: NostrFilter) -> List:
|
||||
|
|
@ -255,14 +267,32 @@ class NostrClientConnection:
|
|||
and len(self.filters) >= self.client_config.max_client_filters
|
||||
)
|
||||
|
||||
async def _validate_write(self, e: NostrEvent) -> Tuple[bool, str]:
|
||||
def _validate_auth_event(self, e: NostrEvent) -> Tuple[bool, str]:
|
||||
valid, message = self._validate_event(e)
|
||||
if not valid:
|
||||
return [valid, message]
|
||||
|
||||
relay_tag = e.tag_values("relay")
|
||||
challenge_tag = e.tag_values("challenge")
|
||||
if len(relay_tag) == 0 or len(challenge_tag) == 0:
|
||||
return False, "NIP42 tags are missing"
|
||||
|
||||
if self.client_config.domain != extract_domain(relay_tag[0]):
|
||||
return False, "Wrong relay domain"
|
||||
|
||||
if self._auth_challenge != challenge_tag[0]:
|
||||
return False, "Wrong chanlange value"
|
||||
|
||||
return True, ""
|
||||
|
||||
async def _validate_write(self, e: NostrEvent) -> Tuple[bool, str]:
|
||||
valid, message = self._validate_event(e)
|
||||
if not valid:
|
||||
return (valid, message)
|
||||
|
||||
valid, message = await self._validate_storage(e.pubkey, e.size_bytes)
|
||||
if not valid:
|
||||
return [valid, message]
|
||||
return (valid, message)
|
||||
|
||||
return True, ""
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from urllib.parse import urlparse
|
||||
|
||||
from bech32 import bech32_decode, convertbits
|
||||
|
||||
|
||||
|
|
@ -17,3 +19,6 @@ def normalize_public_key(pubkey: str) -> str:
|
|||
raise ValueError("Public Key is not valid hex")
|
||||
int(pubkey, 16)
|
||||
return pubkey
|
||||
|
||||
def extract_domain(url: str) -> str:
|
||||
return urlparse(url).netloc
|
||||
|
|
@ -1,6 +1,5 @@
|
|||
from http import HTTPStatus
|
||||
from typing import List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import Depends, Request, WebSocket
|
||||
from fastapi.exceptions import HTTPException
|
||||
|
|
@ -27,7 +26,7 @@ from .crud import (
|
|||
get_relays,
|
||||
update_relay,
|
||||
)
|
||||
from .helpers import normalize_public_key
|
||||
from .helpers import extract_domain, normalize_public_key
|
||||
from .models import BuyOrder, NostrRelay
|
||||
|
||||
client_manager = NostrClientManager()
|
||||
|
|
@ -57,7 +56,7 @@ async def api_create_relay(
|
|||
data.id = urlsafe_short_hash()[:8]
|
||||
|
||||
try:
|
||||
data.config.domain = urlparse(str(request.url)).netloc
|
||||
data.config.domain = extract_domain(str(request.url))
|
||||
relay = await create_relay(wallet.wallet.user, data)
|
||||
return relay
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue