chore: code format

This commit is contained in:
Vlad Stan 2023-02-17 14:26:00 +02:00
parent 855812cb8f
commit 230729483c
7 changed files with 25 additions and 27 deletions

View file

@ -14,10 +14,10 @@ from ..crud import (
get_events, get_events,
mark_events_deleted, mark_events_deleted,
) )
from .relay import RelaySpec
from .event import NostrEvent, NostrEventType from .event import NostrEvent, NostrEventType
from .event_validator import EventValidator from .event_validator import EventValidator
from .filter import NostrFilter from .filter import NostrFilter
from .relay import RelaySpec
class NostrClientConnection: class NostrClientConnection:
@ -25,7 +25,7 @@ class NostrClientConnection:
self.websocket = websocket self.websocket = websocket
self.relay_id = relay_id self.relay_id = relay_id
self.filters: List[NostrFilter] = [] self.filters: List[NostrFilter] = []
self.pubkey: Optional[str] = None # set if authenticated self.pubkey: Optional[str] = None # set if authenticated
self._auth_challenge: Optional[str] = None self._auth_challenge: Optional[str] = None
self._auth_challenge_created_at = 0 self._auth_challenge_created_at = 0
@ -65,7 +65,6 @@ class NostrClientConnection:
setattr(self, "broadcast_event", broadcast_event) setattr(self, "broadcast_event", broadcast_event)
setattr(self, "get_client_config", get_client_config) setattr(self, "get_client_config", get_client_config)
setattr(self.event_validator, "get_client_config", get_client_config) setattr(self.event_validator, "get_client_config", get_client_config)
async def notify_event(self, event: NostrEvent) -> bool: async def notify_event(self, event: NostrEvent) -> bool:
if self._is_direct_message_for_other(event): if self._is_direct_message_for_other(event):
@ -80,8 +79,8 @@ class NostrClientConnection:
def _is_direct_message_for_other(self, event: NostrEvent) -> bool: def _is_direct_message_for_other(self, event: NostrEvent) -> bool:
""" """
Direct messages are not inteded to be boradcast (even if encrypted). Direct messages are not inteded to be boradcast (even if encrypted).
If the server requires AUTH for kind '4' then direct message will be sent only to the intended client. If the server requires AUTH for kind '4' then direct message will be sent only to the intended client.
""" """
if not event.is_direct_message: if not event.is_direct_message:
return False return False
@ -121,7 +120,9 @@ class NostrClientConnection:
resp_nip20: List[Any] = ["OK", e.id] resp_nip20: List[Any] = ["OK", e.id]
if e.is_auth_response_event: if e.is_auth_response_event:
valid, message = self.event_validator.validate_auth_event(e, self._auth_challenge) valid, message = self.event_validator.validate_auth_event(
e, self._auth_challenge
)
if not valid: if not valid:
resp_nip20 += [valid, message] resp_nip20 += [valid, message]
await self._send_msg(resp_nip20) await self._send_msg(resp_nip20)
@ -148,7 +149,8 @@ class NostrClientConnection:
try: try:
if e.is_replaceable_event: if e.is_replaceable_event:
await delete_events( await delete_events(
self.relay_id, NostrFilter(kinds=[e.kind], authors=[e.pubkey], until=e.created_at) self.relay_id,
NostrFilter(kinds=[e.kind], authors=[e.pubkey], until=e.created_at),
) )
if not e.is_ephemeral_event: if not e.is_ephemeral_event:
await create_event(self.relay_id, e, self.pubkey) await create_event(self.relay_id, e, self.pubkey)

View file

@ -1,9 +1,9 @@
from typing import List from typing import List
from ..crud import get_config_for_all_active_relays from ..crud import get_config_for_all_active_relays
from .relay import RelaySpec
from .client_connection import NostrClientConnection from .client_connection import NostrClientConnection
from .event import NostrEvent from .event import NostrEvent
from .relay import RelaySpec
class NostrClientManager: class NostrClientManager:
@ -69,4 +69,3 @@ class NostrClientManager:
setattr(client, "get_client_config", get_client_config) setattr(client, "get_client_config", get_client_config)
client.init_callbacks(self.broadcast_event, get_client_config) client.init_callbacks(self.broadcast_event, get_client_config)

View file

@ -65,7 +65,6 @@ class NostrEvent(BaseModel):
@property @property
def is_ephemeral_event(self) -> bool: def is_ephemeral_event(self) -> bool:
return self.kind >= 20000 and self.kind < 30000 return self.kind >= 20000 and self.kind < 30000
def check_signature(self): def check_signature(self):
event_id = self.event_id event_id = self.event_id
@ -101,4 +100,3 @@ class NostrEvent(BaseModel):
@classmethod @classmethod
def from_row(cls, row: Row) -> "NostrEvent": def from_row(cls, row: Row) -> "NostrEvent":
return cls(**dict(row)) return cls(**dict(row))

View file

@ -9,7 +9,6 @@ from .relay import RelaySpec
class EventValidator: class EventValidator:
def __init__(self, relay_id: str): def __init__(self, relay_id: str):
self.relay_id = relay_id self.relay_id = relay_id
@ -18,7 +17,9 @@ class EventValidator:
self.get_client_config: Optional[Callable[[], RelaySpec]] = None self.get_client_config: Optional[Callable[[], RelaySpec]] = None
async def validate_write(self, e: NostrEvent, publisher_pubkey: str) -> Tuple[bool, str]: async def validate_write(
self, e: NostrEvent, publisher_pubkey: str
) -> Tuple[bool, str]:
valid, message = self._validate_event(e) valid, message = self._validate_event(e)
if not valid: if not valid:
return (valid, message) return (valid, message)
@ -32,7 +33,9 @@ class EventValidator:
return True, "" return True, ""
def validate_auth_event(self, e: NostrEvent, auth_challenge: Optional[str]) -> Tuple[bool, str]: def validate_auth_event(
self, e: NostrEvent, auth_challenge: Optional[str]
) -> Tuple[bool, str]:
valid, message = self._validate_event(e) valid, message = self._validate_event(e)
if not valid: if not valid:
return (valid, message) return (valid, message)
@ -91,9 +94,7 @@ class EventValidator:
return False, f"This is a paid relay: '{self.relay_id}'" return False, f"This is a paid relay: '{self.relay_id}'"
stored_bytes = await get_storage_for_public_key(self.relay_id, pubkey) stored_bytes = await get_storage_for_public_key(self.relay_id, pubkey)
total_available_storage = ( total_available_storage = account.storage + self.config.free_storage_bytes_value
account.storage + self.config.free_storage_bytes_value
)
if (stored_bytes + event_size_bytes) <= total_available_storage: if (stored_bytes + event_size_bytes) <= total_available_storage:
return True, "" return True, ""
@ -110,7 +111,6 @@ class EventValidator:
return True, "" return True, ""
def _exceeded_max_events_per_hour(self) -> bool: def _exceeded_max_events_per_hour(self) -> bool:
if self.config.max_events_per_hour == 0: if self.config.max_events_per_hour == 0:
return False return False
@ -122,9 +122,7 @@ class EventValidator:
self._last_event_timestamp = current_time self._last_event_timestamp = current_time
self._event_count_per_timestamp = 0 self._event_count_per_timestamp = 0
return ( return self._event_count_per_timestamp > self.config.max_events_per_hour
self._event_count_per_timestamp > self.config.max_events_per_hour
)
def _created_at_in_range(self, created_at: int) -> Tuple[bool, str]: def _created_at_in_range(self, created_at: int) -> Tuple[bool, str]:
current_time = round(time.time()) current_time = round(time.time())
@ -134,4 +132,4 @@ class EventValidator:
if self.config.created_at_in_future != 0: if self.config.created_at_in_future != 0:
if created_at > (current_time + self.config.created_at_in_future): if created_at > (current_time + self.config.created_at_in_future):
return False, "created_at is too much into the future" return False, "created_at is too much into the future"
return True, "" return True, ""

View file

@ -1,4 +1,3 @@
from typing import Any, List, Optional, Tuple from typing import Any, List, Optional, Tuple
from pydantic import BaseModel, Field from pydantic import BaseModel, Field

View file

@ -5,7 +5,6 @@ from typing import Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
class Spec(BaseModel): class Spec(BaseModel):
class Config: class Config:
allow_population_by_field_name = True allow_population_by_field_name = True
@ -127,4 +126,3 @@ class NostrRelay(BaseModel):
"software": "LNbits", "software": "LNbits",
"version": "", "version": "",
} }

View file

@ -6,9 +6,13 @@ import pytest
from fastapi import WebSocket from fastapi import WebSocket
from loguru import logger from loguru import logger
from lnbits.extensions.nostrrelay.relay.client_connection import (
NostrClientConnection, # type: ignore
)
from lnbits.extensions.nostrrelay.relay.client_manager import (
NostrClientManager, # type: ignore
)
from lnbits.extensions.nostrrelay.relay.relay import RelaySpec # type: ignore from lnbits.extensions.nostrrelay.relay.relay import RelaySpec # type: ignore
from lnbits.extensions.nostrrelay.relay.client_connection import NostrClientConnection # type: ignore
from lnbits.extensions.nostrrelay.relay.client_manager import NostrClientManager # type: ignore
from .helpers import get_fixtures from .helpers import get_fixtures