feat: improve codequality and CI (#25)
* feat: improve codequality and CI
This commit is contained in:
parent
28121184c3
commit
cc6752003a
28 changed files with 3114 additions and 292 deletions
|
|
@ -3,9 +3,8 @@ import time
|
|||
from typing import Any, Awaitable, Callable, List, Optional
|
||||
|
||||
from fastapi import WebSocket
|
||||
from loguru import logger
|
||||
|
||||
from lnbits.helpers import urlsafe_short_hash
|
||||
from loguru import logger
|
||||
|
||||
from ..crud import (
|
||||
NostrAccount,
|
||||
|
|
@ -55,26 +54,26 @@ class NostrClientConnection:
|
|||
message = reason if reason else "Server closed webocket"
|
||||
try:
|
||||
await self._send_msg(["NOTICE", message])
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
await self.websocket.close(reason=reason)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def init_callbacks(self, broadcast_event: Callable, get_client_config: Callable):
|
||||
setattr(self, "broadcast_event", broadcast_event)
|
||||
setattr(self, "get_client_config", get_client_config)
|
||||
setattr(self.event_validator, "get_client_config", get_client_config)
|
||||
self.broadcast_event = broadcast_event
|
||||
self.get_client_config = get_client_config
|
||||
self.event_validator.get_client_config = get_client_config
|
||||
|
||||
async def notify_event(self, event: NostrEvent) -> bool:
|
||||
if self._is_direct_message_for_other(event):
|
||||
return False
|
||||
|
||||
for filter in self.filters:
|
||||
if filter.matches(event):
|
||||
resp = event.serialize_response(filter.subscription_id)
|
||||
for nostr_filter in self.filters:
|
||||
if nostr_filter.matches(event):
|
||||
resp = event.serialize_response(nostr_filter.subscription_id)
|
||||
await self._send_msg(resp)
|
||||
return True
|
||||
return False
|
||||
|
|
@ -82,7 +81,8 @@ class NostrClientConnection:
|
|||
def _is_direct_message_for_other(self, event: NostrEvent) -> bool:
|
||||
"""
|
||||
Direct messages are not inteded to be boradcast (even if encrypted).
|
||||
If the server requires AUTH for kind '4' then direct message will be sent only to the intended client.
|
||||
If the server requires AUTH for kind '4' then direct message will be
|
||||
sent only to the intended client.
|
||||
"""
|
||||
if not event.is_direct_message:
|
||||
return False
|
||||
|
|
@ -136,7 +136,7 @@ class NostrClientConnection:
|
|||
await self._send_msg(["AUTH", self._current_auth_challenge()])
|
||||
resp_nip20 += [
|
||||
False,
|
||||
f"restricted: Relay requires authentication for events of kind '{e.kind}'",
|
||||
f"Relay requires authentication for events of kind '{e.kind}'",
|
||||
]
|
||||
await self._send_msg(resp_nip20)
|
||||
return None
|
||||
|
|
@ -166,7 +166,7 @@ class NostrClientConnection:
|
|||
event = await get_event(self.relay_id, e.id)
|
||||
# todo: handle NIP20 in detail
|
||||
message = "error: failed to create event"
|
||||
resp_nip20 += [event != None, message]
|
||||
resp_nip20 += [event is not None, message]
|
||||
|
||||
await self._send_msg(resp_nip20)
|
||||
|
||||
|
|
@ -181,13 +181,15 @@ class NostrClientConnection:
|
|||
|
||||
async def _handle_delete_event(self, event: NostrEvent):
|
||||
# NIP 09
|
||||
filter = NostrFilter(authors=[event.pubkey])
|
||||
filter.ids = [t[1] for t in event.tags if t[0] == "e"]
|
||||
events_to_delete = await get_events(self.relay_id, filter, False)
|
||||
nostr_filter = NostrFilter(authors=[event.pubkey])
|
||||
nostr_filter.ids = [t[1] for t in event.tags if t[0] == "e"]
|
||||
events_to_delete = await get_events(self.relay_id, nostr_filter, False)
|
||||
ids = [e.id for e in events_to_delete if not e.is_delete_event]
|
||||
await mark_events_deleted(self.relay_id, NostrFilter(ids=ids))
|
||||
|
||||
async def _handle_request(self, subscription_id: str, filter: NostrFilter) -> List:
|
||||
async def _handle_request(
|
||||
self, subscription_id: str, nostr_filter: NostrFilter
|
||||
) -> List:
|
||||
if self.config.require_auth_filter:
|
||||
if not self.auth_pubkey:
|
||||
return [["AUTH", self._current_auth_challenge()]]
|
||||
|
|
@ -199,26 +201,30 @@ class NostrClientConnection:
|
|||
return [
|
||||
[
|
||||
"NOTICE",
|
||||
f"Public key '{self.auth_pubkey}' is not allowed in relay '{self.relay_id}'!",
|
||||
(
|
||||
f"Public key '{self.auth_pubkey}' is not allowed "
|
||||
f"in relay '{self.relay_id}'!"
|
||||
),
|
||||
]
|
||||
]
|
||||
|
||||
if not account.can_join and not self.config.is_free_to_join:
|
||||
return [["NOTICE", f"This is a paid relay: '{self.relay_id}'"]]
|
||||
|
||||
filter.subscription_id = subscription_id
|
||||
nostr_filter.subscription_id = subscription_id
|
||||
self._remove_filter(subscription_id)
|
||||
if self._can_add_filter():
|
||||
max_filters = self.config.max_client_filters
|
||||
return [
|
||||
[
|
||||
"NOTICE",
|
||||
f"Maximum number of filters ({self.config.max_client_filters}) exceeded.",
|
||||
f"Maximum number of filters ({max_filters}) exceeded.",
|
||||
]
|
||||
]
|
||||
|
||||
filter.enforce_limit(self.config.limit_per_filter)
|
||||
self.filters.append(filter)
|
||||
events = await get_events(self.relay_id, filter)
|
||||
nostr_filter.enforce_limit(self.config.limit_per_filter)
|
||||
self.filters.append(nostr_filter)
|
||||
events = await get_events(self.relay_id, nostr_filter)
|
||||
events = [e for e in events if not self._is_direct_message_for_other(e)]
|
||||
serialized_events = [
|
||||
event.serialize_response(subscription_id) for event in events
|
||||
|
|
|
|||
|
|
@ -71,5 +71,5 @@ class NostrClientManager:
|
|||
def get_client_config() -> RelaySpec:
|
||||
return self.get_relay_config(client.relay_id)
|
||||
|
||||
setattr(client, "get_client_config", get_client_config)
|
||||
client.get_client_config = get_client_config
|
||||
client.init_callbacks(self.broadcast_event, get_client_config)
|
||||
|
|
|
|||
|
|
@ -34,8 +34,7 @@ class NostrEvent(BaseModel):
|
|||
@property
|
||||
def event_id(self) -> str:
|
||||
data = self.serialize_json()
|
||||
id = hashlib.sha256(data.encode()).hexdigest()
|
||||
return id
|
||||
return hashlib.sha256(data.encode()).hexdigest()
|
||||
|
||||
@property
|
||||
def size_bytes(self) -> int:
|
||||
|
|
@ -74,10 +73,10 @@ class NostrEvent(BaseModel):
|
|||
)
|
||||
try:
|
||||
pub_key = PublicKey(bytes.fromhex("02" + self.pubkey), True)
|
||||
except Exception:
|
||||
except Exception as exc:
|
||||
raise ValueError(
|
||||
f"Invalid public key: '{self.pubkey}' for event '{self.id}'"
|
||||
)
|
||||
) from exc
|
||||
|
||||
valid_signature = pub_key.schnorr_verify(
|
||||
bytes.fromhex(event_id), bytes.fromhex(self.sig), None, raw=True
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ class EventValidator:
|
|||
|
||||
def _validate_event(self, e: NostrEvent) -> Tuple[bool, str]:
|
||||
if self._exceeded_max_events_per_hour():
|
||||
return False, f"Exceeded max events per hour limit'!"
|
||||
return False, "Exceeded max events per hour limit'!"
|
||||
|
||||
try:
|
||||
e.check_signature()
|
||||
|
|
@ -101,7 +101,7 @@ class EventValidator:
|
|||
if self.config.full_storage_action == "block":
|
||||
return (
|
||||
False,
|
||||
f"Cannot write event, no more storage available for public key: '{pubkey}'",
|
||||
f"Cannot write event, no storage available for public key: '{pubkey}'",
|
||||
)
|
||||
|
||||
if event_size_bytes > total_available_storage:
|
||||
|
|
|
|||
|
|
@ -6,16 +6,15 @@ from .event import NostrEvent
|
|||
|
||||
|
||||
class NostrFilter(BaseModel):
|
||||
subscription_id: Optional[str]
|
||||
|
||||
e: List[str] = Field(default=[], alias="#e")
|
||||
p: List[str] = Field(default=[], alias="#p")
|
||||
ids: List[str] = []
|
||||
authors: List[str] = []
|
||||
kinds: List[int] = []
|
||||
e: List[str] = Field([], alias="#e")
|
||||
p: List[str] = Field([], alias="#p")
|
||||
since: Optional[int]
|
||||
until: Optional[int]
|
||||
limit: Optional[int]
|
||||
subscription_id: Optional[str] = None
|
||||
since: Optional[int] = None
|
||||
until: Optional[int] = None
|
||||
limit: Optional[int] = None
|
||||
|
||||
def matches(self, e: NostrEvent) -> bool:
|
||||
# todo: starts with
|
||||
|
|
@ -78,7 +77,8 @@ class NostrFilter(BaseModel):
|
|||
values += self.e
|
||||
e_s = ",".join(["?"] * len(self.e))
|
||||
inner_joins.append(
|
||||
"INNER JOIN nostrrelay.event_tags e_tags ON nostrrelay.events.id = e_tags.event_id"
|
||||
"INNER JOIN nostrrelay.event_tags e_tags "
|
||||
"ON nostrrelay.events.id = e_tags.event_id"
|
||||
)
|
||||
where.append(f" (e_tags.value in ({e_s}) AND e_tags.name = 'e')")
|
||||
|
||||
|
|
@ -86,7 +86,8 @@ class NostrFilter(BaseModel):
|
|||
values += self.p
|
||||
p_s = ",".join(["?"] * len(self.p))
|
||||
inner_joins.append(
|
||||
"INNER JOIN nostrrelay.event_tags p_tags ON nostrrelay.events.id = p_tags.event_id"
|
||||
"INNER JOIN nostrrelay.event_tags p_tags "
|
||||
"ON nostrrelay.events.id = p_tags.event_id"
|
||||
)
|
||||
where.append(f" p_tags.value in ({p_s}) AND p_tags.name = 'p'")
|
||||
|
||||
|
|
|
|||
|
|
@ -11,22 +11,22 @@ class Spec(BaseModel):
|
|||
|
||||
|
||||
class FilterSpec(Spec):
|
||||
max_client_filters = Field(0, alias="maxClientFilters")
|
||||
limit_per_filter = Field(1000, alias="limitPerFilter")
|
||||
max_client_filters: int = Field(default=0, alias="maxClientFilters")
|
||||
limit_per_filter: int = Field(default=1000, alias="limitPerFilter")
|
||||
|
||||
|
||||
class EventSpec(Spec):
|
||||
max_events_per_hour = Field(0, alias="maxEventsPerHour")
|
||||
max_events_per_hour: int = Field(default=0, alias="maxEventsPerHour")
|
||||
|
||||
created_at_days_past = Field(0, alias="createdAtDaysPast")
|
||||
created_at_hours_past = Field(0, alias="createdAtHoursPast")
|
||||
created_at_minutes_past = Field(0, alias="createdAtMinutesPast")
|
||||
created_at_seconds_past = Field(0, alias="createdAtSecondsPast")
|
||||
created_at_days_past: int = Field(default=0, alias="createdAtDaysPast")
|
||||
created_at_hours_past: int = Field(default=0, alias="createdAtHoursPast")
|
||||
created_at_minutes_past: int = Field(default=0, alias="createdAtMinutesPast")
|
||||
created_at_seconds_past: int = Field(default=0, alias="createdAtSecondsPast")
|
||||
|
||||
created_at_days_future = Field(0, alias="createdAtDaysFuture")
|
||||
created_at_hours_future = Field(0, alias="createdAtHoursFuture")
|
||||
created_at_minutes_future = Field(0, alias="createdAtMinutesFuture")
|
||||
created_at_seconds_future = Field(0, alias="createdAtSecondsFuture")
|
||||
created_at_days_future: int = Field(default=0, alias="createdAtDaysFuture")
|
||||
created_at_hours_future: int = Field(default=0, alias="createdAtHoursFuture")
|
||||
created_at_minutes_future: int = Field(default=0, alias="createdAtMinutesFuture")
|
||||
created_at_seconds_future: int = Field(default=0, alias="createdAtSecondsFuture")
|
||||
|
||||
@property
|
||||
def created_at_in_past(self) -> int:
|
||||
|
|
@ -48,9 +48,9 @@ class EventSpec(Spec):
|
|||
|
||||
|
||||
class StorageSpec(Spec):
|
||||
free_storage_value = Field(1, alias="freeStorageValue")
|
||||
free_storage_unit = Field("MB", alias="freeStorageUnit")
|
||||
full_storage_action = Field("prune", alias="fullStorageAction")
|
||||
free_storage_value: int = Field(default=1, alias="freeStorageValue")
|
||||
free_storage_unit: str = Field(default="MB", alias="freeStorageUnit")
|
||||
full_storage_action: str = Field(default="prune", alias="fullStorageAction")
|
||||
|
||||
@property
|
||||
def free_storage_bytes_value(self):
|
||||
|
|
@ -61,10 +61,10 @@ class StorageSpec(Spec):
|
|||
|
||||
|
||||
class AuthSpec(Spec):
|
||||
require_auth_events = Field(False, alias="requireAuthEvents")
|
||||
skiped_auth_events = Field([], alias="skipedAuthEvents")
|
||||
forced_auth_events = Field([], alias="forcedAuthEvents")
|
||||
require_auth_filter = Field(False, alias="requireAuthFilter")
|
||||
require_auth_events: bool = Field(default=False, alias="requireAuthEvents")
|
||||
skiped_auth_events: list = Field(default=[], alias="skipedAuthEvents")
|
||||
forced_auth_events: list = Field(default=[], alias="forcedAuthEvents")
|
||||
require_auth_filter: bool = Field(default=False, alias="requireAuthFilter")
|
||||
|
||||
def event_requires_auth(self, kind: int) -> bool:
|
||||
if self.require_auth_events:
|
||||
|
|
@ -73,11 +73,11 @@ class AuthSpec(Spec):
|
|||
|
||||
|
||||
class PaymentSpec(Spec):
|
||||
is_paid_relay = Field(False, alias="isPaidRelay")
|
||||
cost_to_join = Field(0, alias="costToJoin")
|
||||
is_paid_relay: bool = Field(default=False, alias="isPaidRelay")
|
||||
cost_to_join: int = Field(default=0, alias="costToJoin")
|
||||
|
||||
storage_cost_value = Field(0, alias="storageCostValue")
|
||||
storage_cost_unit = Field("MB", alias="storageCostUnit")
|
||||
storage_cost_value: int = Field(default=0, alias="storageCostValue")
|
||||
storage_cost_unit: str = Field(default="MB", alias="storageCostUnit")
|
||||
|
||||
@property
|
||||
def is_free_to_join(self):
|
||||
|
|
@ -85,7 +85,7 @@ class PaymentSpec(Spec):
|
|||
|
||||
|
||||
class WalletSpec(Spec):
|
||||
wallet = Field("")
|
||||
wallet: str = Field(default="")
|
||||
|
||||
|
||||
class RelayPublicSpec(FilterSpec, EventSpec, StorageSpec, PaymentSpec):
|
||||
|
|
@ -93,7 +93,7 @@ class RelayPublicSpec(FilterSpec, EventSpec, StorageSpec, PaymentSpec):
|
|||
|
||||
@property
|
||||
def is_read_only_relay(self):
|
||||
self.free_storage_value == 0 and not self.is_paid_relay
|
||||
return self.free_storage_value == 0 and not self.is_paid_relay
|
||||
|
||||
|
||||
class RelaySpec(RelayPublicSpec, WalletSpec, AuthSpec):
|
||||
|
|
@ -108,7 +108,7 @@ class NostrRelay(BaseModel):
|
|||
contact: Optional[str]
|
||||
active: bool = False
|
||||
|
||||
config: "RelaySpec" = RelaySpec()
|
||||
config = RelaySpec()
|
||||
|
||||
@property
|
||||
def is_free_to_join(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue