refactor: extract EventValidator

This commit is contained in:
Vlad Stan 2023-02-17 14:00:19 +02:00
parent c46c903703
commit aa68d2a79a
4 changed files with 140 additions and 136 deletions

View file

@ -1,6 +1,6 @@
import json import json
import time import time
from typing import Any, Awaitable, Callable, List, Optional, Tuple from typing import Any, Awaitable, Callable, List, Optional
from fastapi import WebSocket from fastapi import WebSocket
from loguru import logger from loguru import logger
@ -10,16 +10,13 @@ from lnbits.helpers import urlsafe_short_hash
from ..crud import ( from ..crud import (
create_event, create_event,
delete_events, delete_events,
get_account,
get_config_for_all_active_relays, get_config_for_all_active_relays,
get_event, get_event,
get_events, get_events,
get_storage_for_public_key,
mark_events_deleted, mark_events_deleted,
prune_old_events,
) )
from ..helpers import extract_domain from ..models import NostrEvent, NostrEventType, NostrFilter, RelaySpec
from ..models import NostrAccount, NostrEvent, NostrEventType, NostrFilter, RelaySpec from .event_validator import EventValidator
class NostrClientManager: class NostrClientManager:
@ -306,131 +303,3 @@ class NostrClientConnection:
return self._auth_challenge return self._auth_challenge
class EventValidator:
def __init__(self, relay_id: str):
self.relay_id = relay_id
self._last_event_timestamp = 0 # in hours
self._event_count_per_timestamp = 0
self.get_client_config: Optional[Callable[[], RelaySpec]] = None
async def validate_write(self, e: NostrEvent, publisher_pubkey: str) -> Tuple[bool, str]:
valid, message = self._validate_event(e)
if not valid:
return (valid, message)
if e.is_ephemeral_event:
return True, ""
valid, message = await self._validate_storage(publisher_pubkey, e.size_bytes)
if not valid:
return (valid, message)
return True, ""
def validate_auth_event(self, e: NostrEvent, auth_challenge: Optional[str]) -> Tuple[bool, str]:
valid, message = self._validate_event(e)
if not valid:
return (valid, message)
relay_tag = e.tag_values("relay")
challenge_tag = e.tag_values("challenge")
if len(relay_tag) == 0 or len(challenge_tag) == 0:
return False, "error: NIP42 tags are missing for auth event"
if self.config.domain != extract_domain(relay_tag[0]):
return False, "error: wrong relay domain for auth event"
if auth_challenge != challenge_tag[0]:
return False, "error: wrong chanlange value for auth event"
return True, ""
@property
def config(self) -> RelaySpec:
if not self.get_client_config:
raise Exception("EventValidator not ready!")
return self.get_client_config()
def _validate_event(self, e: NostrEvent) -> Tuple[bool, str]:
if self._exceeded_max_events_per_hour():
return False, f"Exceeded max events per hour limit'!"
try:
e.check_signature()
except ValueError:
return False, "invalid: wrong event `id` or `sig`"
in_range, message = self._created_at_in_range(e.created_at)
if not in_range:
return False, message
return True, ""
async def _validate_storage(
self, pubkey: str, event_size_bytes: int
) -> Tuple[bool, str]:
if self.config.is_read_only_relay:
return False, "Cannot write event, relay is read-only"
account = await get_account(self.relay_id, pubkey)
if not account:
account = NostrAccount.null_account()
if account.blocked:
return (
False,
f"Public key '{pubkey}' is not allowed in relay '{self.relay_id}'!",
)
if not account.can_join and self.config.is_paid_relay:
return False, f"This is a paid relay: '{self.relay_id}'"
stored_bytes = await get_storage_for_public_key(self.relay_id, pubkey)
total_available_storage = (
account.storage + self.config.free_storage_bytes_value
)
if (stored_bytes + event_size_bytes) <= total_available_storage:
return True, ""
if self.config.full_storage_action == "block":
return (
False,
f"Cannot write event, no more storage available for public key: '{pubkey}'",
)
if event_size_bytes > total_available_storage:
return False, "Message is too large. Not enough storage available for it."
await prune_old_events(self.relay_id, pubkey, event_size_bytes)
return True, ""
def _exceeded_max_events_per_hour(self) -> bool:
if self.config.max_events_per_hour == 0:
return False
current_time = round(time.time() / 3600)
if self._last_event_timestamp == current_time:
self._event_count_per_timestamp += 1
else:
self._last_event_timestamp = current_time
self._event_count_per_timestamp = 0
return (
self._event_count_per_timestamp > self.config.max_events_per_hour
)
def _created_at_in_range(self, created_at: int) -> Tuple[bool, str]:
current_time = round(time.time())
if self.config.created_at_in_past != 0:
if created_at < (current_time - self.config.created_at_in_past):
return False, "created_at is too much into the past"
if self.config.created_at_in_future != 0:
if created_at > (current_time + self.config.created_at_in_future):
return False, "created_at is too much into the future"
return True, ""

135
relay/event_validator.py Normal file
View file

@ -0,0 +1,135 @@
import time
from typing import Callable, Optional, Tuple
from ..crud import get_account, get_storage_for_public_key, prune_old_events
from ..helpers import extract_domain
from ..models import NostrAccount, NostrEvent, RelaySpec
class EventValidator:
def __init__(self, relay_id: str):
self.relay_id = relay_id
self._last_event_timestamp = 0 # in hours
self._event_count_per_timestamp = 0
self.get_client_config: Optional[Callable[[], RelaySpec]] = None
async def validate_write(self, e: NostrEvent, publisher_pubkey: str) -> Tuple[bool, str]:
valid, message = self._validate_event(e)
if not valid:
return (valid, message)
if e.is_ephemeral_event:
return True, ""
valid, message = await self._validate_storage(publisher_pubkey, e.size_bytes)
if not valid:
return (valid, message)
return True, ""
def validate_auth_event(self, e: NostrEvent, auth_challenge: Optional[str]) -> Tuple[bool, str]:
valid, message = self._validate_event(e)
if not valid:
return (valid, message)
relay_tag = e.tag_values("relay")
challenge_tag = e.tag_values("challenge")
if len(relay_tag) == 0 or len(challenge_tag) == 0:
return False, "error: NIP42 tags are missing for auth event"
if self.config.domain != extract_domain(relay_tag[0]):
return False, "error: wrong relay domain for auth event"
if auth_challenge != challenge_tag[0]:
return False, "error: wrong chanlange value for auth event"
return True, ""
@property
def config(self) -> RelaySpec:
if not self.get_client_config:
raise Exception("EventValidator not ready!")
return self.get_client_config()
def _validate_event(self, e: NostrEvent) -> Tuple[bool, str]:
if self._exceeded_max_events_per_hour():
return False, f"Exceeded max events per hour limit'!"
try:
e.check_signature()
except ValueError:
return False, "invalid: wrong event `id` or `sig`"
in_range, message = self._created_at_in_range(e.created_at)
if not in_range:
return False, message
return True, ""
async def _validate_storage(
self, pubkey: str, event_size_bytes: int
) -> Tuple[bool, str]:
if self.config.is_read_only_relay:
return False, "Cannot write event, relay is read-only"
account = await get_account(self.relay_id, pubkey)
if not account:
account = NostrAccount.null_account()
if account.blocked:
return (
False,
f"Public key '{pubkey}' is not allowed in relay '{self.relay_id}'!",
)
if not account.can_join and self.config.is_paid_relay:
return False, f"This is a paid relay: '{self.relay_id}'"
stored_bytes = await get_storage_for_public_key(self.relay_id, pubkey)
total_available_storage = (
account.storage + self.config.free_storage_bytes_value
)
if (stored_bytes + event_size_bytes) <= total_available_storage:
return True, ""
if self.config.full_storage_action == "block":
return (
False,
f"Cannot write event, no more storage available for public key: '{pubkey}'",
)
if event_size_bytes > total_available_storage:
return False, "Message is too large. Not enough storage available for it."
await prune_old_events(self.relay_id, pubkey, event_size_bytes)
return True, ""
def _exceeded_max_events_per_hour(self) -> bool:
if self.config.max_events_per_hour == 0:
return False
current_time = round(time.time() / 3600)
if self._last_event_timestamp == current_time:
self._event_count_per_timestamp += 1
else:
self._last_event_timestamp = current_time
self._event_count_per_timestamp = 0
return (
self._event_count_per_timestamp > self.config.max_events_per_hour
)
def _created_at_in_range(self, created_at: int) -> Tuple[bool, str]:
current_time = round(time.time())
if self.config.created_at_in_past != 0:
if created_at < (current_time - self.config.created_at_in_past):
return False, "created_at is too much into the past"
if self.config.created_at_in_future != 0:
if created_at > (current_time + self.config.created_at_in_future):
return False, "created_at is too much into the future"
return True, ""

View file

@ -6,11 +6,11 @@ import pytest
from fastapi import WebSocket from fastapi import WebSocket
from loguru import logger from loguru import logger
from lnbits.extensions.nostrrelay.models import RelaySpec # type: ignore
from lnbits.extensions.nostrrelay.relay.client_manager import ( # type: ignore from lnbits.extensions.nostrrelay.relay.client_manager import ( # type: ignore
NostrClientConnection, NostrClientConnection,
NostrClientManager, NostrClientManager,
) )
from lnbits.extensions.nostrrelay.models import RelaySpec # type: ignore
from .helpers import get_fixtures from .helpers import get_fixtures

View file

@ -16,7 +16,6 @@ from lnbits.decorators import (
from lnbits.helpers import urlsafe_short_hash from lnbits.helpers import urlsafe_short_hash
from . import nostrrelay_ext from . import nostrrelay_ext
from .relay.client_manager import NostrClientConnection, NostrClientManager
from .crud import ( from .crud import (
create_account, create_account,
create_relay, create_relay,
@ -32,6 +31,7 @@ from .crud import (
) )
from .helpers import extract_domain, normalize_public_key from .helpers import extract_domain, normalize_public_key
from .models import BuyOrder, NostrAccount, NostrPartialAccount, NostrRelay from .models import BuyOrder, NostrAccount, NostrPartialAccount, NostrRelay
from .relay.client_manager import NostrClientConnection, NostrClientManager
client_manager = NostrClientManager() client_manager = NostrClientManager()