diff --git a/relay/client_connection.py b/relay/client_connection.py index 5af1206..92a70ed 100644 --- a/relay/client_connection.py +++ b/relay/client_connection.py @@ -14,10 +14,10 @@ from ..crud import ( get_events, mark_events_deleted, ) -from .relay import RelaySpec from .event import NostrEvent, NostrEventType from .event_validator import EventValidator from .filter import NostrFilter +from .relay import RelaySpec class NostrClientConnection: @@ -25,7 +25,7 @@ class NostrClientConnection: self.websocket = websocket self.relay_id = relay_id self.filters: List[NostrFilter] = [] - self.pubkey: Optional[str] = None # set if authenticated + self.pubkey: Optional[str] = None # set if authenticated self._auth_challenge: Optional[str] = None self._auth_challenge_created_at = 0 @@ -65,7 +65,6 @@ class NostrClientConnection: setattr(self, "broadcast_event", broadcast_event) setattr(self, "get_client_config", get_client_config) setattr(self.event_validator, "get_client_config", get_client_config) - async def notify_event(self, event: NostrEvent) -> bool: if self._is_direct_message_for_other(event): @@ -80,8 +79,8 @@ class NostrClientConnection: def _is_direct_message_for_other(self, event: NostrEvent) -> bool: """ - Direct messages are not inteded to be boradcast (even if encrypted). - If the server requires AUTH for kind '4' then direct message will be sent only to the intended client. + Direct messages are not inteded to be boradcast (even if encrypted). + If the server requires AUTH for kind '4' then direct message will be sent only to the intended client. """ if not event.is_direct_message: return False @@ -121,7 +120,9 @@ class NostrClientConnection: resp_nip20: List[Any] = ["OK", e.id] if e.is_auth_response_event: - valid, message = self.event_validator.validate_auth_event(e, self._auth_challenge) + valid, message = self.event_validator.validate_auth_event( + e, self._auth_challenge + ) if not valid: resp_nip20 += [valid, message] await self._send_msg(resp_nip20) @@ -148,7 +149,8 @@ class NostrClientConnection: try: if e.is_replaceable_event: await delete_events( - self.relay_id, NostrFilter(kinds=[e.kind], authors=[e.pubkey], until=e.created_at) + self.relay_id, + NostrFilter(kinds=[e.kind], authors=[e.pubkey], until=e.created_at), ) if not e.is_ephemeral_event: await create_event(self.relay_id, e, self.pubkey) diff --git a/relay/client_manager.py b/relay/client_manager.py index 33ecd96..c2db58d 100644 --- a/relay/client_manager.py +++ b/relay/client_manager.py @@ -1,9 +1,9 @@ from typing import List from ..crud import get_config_for_all_active_relays -from .relay import RelaySpec from .client_connection import NostrClientConnection from .event import NostrEvent +from .relay import RelaySpec class NostrClientManager: @@ -69,4 +69,3 @@ class NostrClientManager: setattr(client, "get_client_config", get_client_config) client.init_callbacks(self.broadcast_event, get_client_config) - diff --git a/relay/event.py b/relay/event.py index 15b3c02..1a70e2e 100644 --- a/relay/event.py +++ b/relay/event.py @@ -65,7 +65,6 @@ class NostrEvent(BaseModel): @property def is_ephemeral_event(self) -> bool: return self.kind >= 20000 and self.kind < 30000 - def check_signature(self): event_id = self.event_id @@ -101,4 +100,3 @@ class NostrEvent(BaseModel): @classmethod def from_row(cls, row: Row) -> "NostrEvent": return cls(**dict(row)) - diff --git a/relay/event_validator.py b/relay/event_validator.py index 02c1c61..3b620bf 100644 --- a/relay/event_validator.py +++ b/relay/event_validator.py @@ -9,7 +9,6 @@ from .relay import RelaySpec class EventValidator: - def __init__(self, relay_id: str): self.relay_id = relay_id @@ -18,7 +17,9 @@ class EventValidator: self.get_client_config: Optional[Callable[[], RelaySpec]] = None - async def validate_write(self, e: NostrEvent, publisher_pubkey: str) -> Tuple[bool, str]: + async def validate_write( + self, e: NostrEvent, publisher_pubkey: str + ) -> Tuple[bool, str]: valid, message = self._validate_event(e) if not valid: return (valid, message) @@ -32,7 +33,9 @@ class EventValidator: return True, "" - def validate_auth_event(self, e: NostrEvent, auth_challenge: Optional[str]) -> Tuple[bool, str]: + def validate_auth_event( + self, e: NostrEvent, auth_challenge: Optional[str] + ) -> Tuple[bool, str]: valid, message = self._validate_event(e) if not valid: return (valid, message) @@ -91,9 +94,7 @@ class EventValidator: return False, f"This is a paid relay: '{self.relay_id}'" stored_bytes = await get_storage_for_public_key(self.relay_id, pubkey) - total_available_storage = ( - account.storage + self.config.free_storage_bytes_value - ) + total_available_storage = account.storage + self.config.free_storage_bytes_value if (stored_bytes + event_size_bytes) <= total_available_storage: return True, "" @@ -110,7 +111,6 @@ class EventValidator: return True, "" - def _exceeded_max_events_per_hour(self) -> bool: if self.config.max_events_per_hour == 0: return False @@ -122,9 +122,7 @@ class EventValidator: self._last_event_timestamp = current_time self._event_count_per_timestamp = 0 - return ( - self._event_count_per_timestamp > self.config.max_events_per_hour - ) + return self._event_count_per_timestamp > self.config.max_events_per_hour def _created_at_in_range(self, created_at: int) -> Tuple[bool, str]: current_time = round(time.time()) @@ -134,4 +132,4 @@ class EventValidator: if self.config.created_at_in_future != 0: if created_at > (current_time + self.config.created_at_in_future): return False, "created_at is too much into the future" - return True, "" \ No newline at end of file + return True, "" diff --git a/relay/filter.py b/relay/filter.py index 68f740d..f423eff 100644 --- a/relay/filter.py +++ b/relay/filter.py @@ -1,4 +1,3 @@ - from typing import Any, List, Optional, Tuple from pydantic import BaseModel, Field diff --git a/relay/relay.py b/relay/relay.py index 2d7e876..71f2f24 100644 --- a/relay/relay.py +++ b/relay/relay.py @@ -5,7 +5,6 @@ from typing import Optional from pydantic import BaseModel, Field - class Spec(BaseModel): class Config: allow_population_by_field_name = True @@ -127,4 +126,3 @@ class NostrRelay(BaseModel): "software": "LNbits", "version": "", } - diff --git a/tests/test_clients.py b/tests/test_clients.py index b26ec26..e79c3bf 100644 --- a/tests/test_clients.py +++ b/tests/test_clients.py @@ -6,9 +6,13 @@ import pytest from fastapi import WebSocket from loguru import logger +from lnbits.extensions.nostrrelay.relay.client_connection import ( + NostrClientConnection, # type: ignore +) +from lnbits.extensions.nostrrelay.relay.client_manager import ( + NostrClientManager, # type: ignore +) from lnbits.extensions.nostrrelay.relay.relay import RelaySpec # type: ignore -from lnbits.extensions.nostrrelay.relay.client_connection import NostrClientConnection # type: ignore -from lnbits.extensions.nostrrelay.relay.client_manager import NostrClientManager # type: ignore from .helpers import get_fixtures