refactor: extract extract_domain function

This commit is contained in:
Vlad Stan 2023-02-15 10:33:56 +02:00
parent 3648dc212c
commit 1c81ca300f
3 changed files with 42 additions and 8 deletions

View file

@ -18,6 +18,7 @@ from .crud import (
mark_events_deleted, mark_events_deleted,
prune_old_events, prune_old_events,
) )
from .helpers import extract_domain
from .models import NostrAccount, NostrEvent, NostrEventType, NostrFilter, RelaySpec from .models import NostrAccount, NostrEvent, NostrEventType, NostrFilter, RelaySpec
@ -166,6 +167,17 @@ class NostrClientConnection:
async def _handle_event(self, e: NostrEvent): async def _handle_event(self, e: NostrEvent):
logger.info(f"nostr event: [{e.kind}, {e.pubkey}, '{e.content}']") logger.info(f"nostr event: [{e.kind}, {e.pubkey}, '{e.content}']")
resp_nip20: List[Any] = ["OK", e.id] resp_nip20: List[Any] = ["OK", e.id]
if e.is_auth_response_event:
valid, message = self._validate_auth_event(e)
if not valid:
resp_nip20 += [valid, message]
await self._send_msg(resp_nip20)
return None
self.authenticated = True
return None
if not self.authenticated and self.client_config.event_requires_auth(e.kind): if not self.authenticated and self.client_config.event_requires_auth(e.kind):
await self._send_msg(["AUTH", self._current_auth_challenge()]) await self._send_msg(["AUTH", self._current_auth_challenge()])
resp_nip20 += [False, "Relay requires authentication"] resp_nip20 += [False, "Relay requires authentication"]
@ -180,14 +192,14 @@ class NostrClientConnection:
return None return None
try: try:
if e.is_replaceable_event(): if e.is_replaceable_event:
await delete_events( await delete_events(
self.relay_id, NostrFilter(kinds=[e.kind], authors=[e.pubkey]) self.relay_id, NostrFilter(kinds=[e.kind], authors=[e.pubkey])
) )
await create_event(self.relay_id, e) await create_event(self.relay_id, e)
await self._broadcast_event(e) await self._broadcast_event(e)
if e.is_delete_event(): if e.is_delete_event:
await self._handle_delete_event(e) await self._handle_delete_event(e)
resp_nip20 += [True, ""] resp_nip20 += [True, ""]
except Exception as ex: except Exception as ex:
@ -213,7 +225,7 @@ class NostrClientConnection:
filter = NostrFilter(authors=[event.pubkey]) filter = NostrFilter(authors=[event.pubkey])
filter.ids = [t[1] for t in event.tags if t[0] == "e"] filter.ids = [t[1] for t in event.tags if t[0] == "e"]
events_to_delete = await get_events(self.relay_id, filter, False) events_to_delete = await get_events(self.relay_id, filter, False)
ids = [e.id for e in events_to_delete if not e.is_delete_event()] ids = [e.id for e in events_to_delete if not e.is_delete_event]
await mark_events_deleted(self.relay_id, NostrFilter(ids=ids)) await mark_events_deleted(self.relay_id, NostrFilter(ids=ids))
async def _handle_request(self, subscription_id: str, filter: NostrFilter) -> List: async def _handle_request(self, subscription_id: str, filter: NostrFilter) -> List:
@ -255,14 +267,32 @@ class NostrClientConnection:
and len(self.filters) >= self.client_config.max_client_filters and len(self.filters) >= self.client_config.max_client_filters
) )
async def _validate_write(self, e: NostrEvent) -> Tuple[bool, str]: def _validate_auth_event(self, e: NostrEvent) -> Tuple[bool, str]:
valid, message = self._validate_event(e) valid, message = self._validate_event(e)
if not valid: if not valid:
return [valid, message] return [valid, message]
relay_tag = e.tag_values("relay")
challenge_tag = e.tag_values("challenge")
if len(relay_tag) == 0 or len(challenge_tag) == 0:
return False, "NIP42 tags are missing"
if self.client_config.domain != extract_domain(relay_tag[0]):
return False, "Wrong relay domain"
if self._auth_challenge != challenge_tag[0]:
return False, "Wrong chanlange value"
return True, ""
async def _validate_write(self, e: NostrEvent) -> Tuple[bool, str]:
valid, message = self._validate_event(e)
if not valid:
return (valid, message)
valid, message = await self._validate_storage(e.pubkey, e.size_bytes) valid, message = await self._validate_storage(e.pubkey, e.size_bytes)
if not valid: if not valid:
return [valid, message] return (valid, message)
return True, "" return True, ""

View file

@ -1,3 +1,5 @@
from urllib.parse import urlparse
from bech32 import bech32_decode, convertbits from bech32 import bech32_decode, convertbits
@ -17,3 +19,6 @@ def normalize_public_key(pubkey: str) -> str:
raise ValueError("Public Key is not valid hex") raise ValueError("Public Key is not valid hex")
int(pubkey, 16) int(pubkey, 16)
return pubkey return pubkey
def extract_domain(url: str) -> str:
return urlparse(url).netloc

View file

@ -1,6 +1,5 @@
from http import HTTPStatus from http import HTTPStatus
from typing import List, Optional from typing import List, Optional
from urllib.parse import urlparse
from fastapi import Depends, Request, WebSocket from fastapi import Depends, Request, WebSocket
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
@ -27,7 +26,7 @@ from .crud import (
get_relays, get_relays,
update_relay, update_relay,
) )
from .helpers import normalize_public_key from .helpers import extract_domain, normalize_public_key
from .models import BuyOrder, NostrRelay from .models import BuyOrder, NostrRelay
client_manager = NostrClientManager() client_manager = NostrClientManager()
@ -57,7 +56,7 @@ async def api_create_relay(
data.id = urlsafe_short_hash()[:8] data.id = urlsafe_short_hash()[:8]
try: try:
data.config.domain = urlparse(str(request.url)).netloc data.config.domain = extract_domain(str(request.url))
relay = await create_relay(wallet.wallet.user, data) relay = await create_relay(wallet.wallet.user, data)
return relay return relay