feat: improve codequality and CI (#25)

* feat: improve codequality and CI
This commit is contained in:
dni ⚡ 2024-08-30 13:20:23 +02:00 committed by GitHub
parent 28121184c3
commit cc6752003a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
28 changed files with 3114 additions and 292 deletions

View file

@ -3,9 +3,8 @@ import time
from typing import Any, Awaitable, Callable, List, Optional
from fastapi import WebSocket
from loguru import logger
from lnbits.helpers import urlsafe_short_hash
from loguru import logger
from ..crud import (
NostrAccount,
@ -55,26 +54,26 @@ class NostrClientConnection:
message = reason if reason else "Server closed webocket"
try:
await self._send_msg(["NOTICE", message])
except:
except Exception:
pass
try:
await self.websocket.close(reason=reason)
except:
except Exception:
pass
def init_callbacks(self, broadcast_event: Callable, get_client_config: Callable):
setattr(self, "broadcast_event", broadcast_event)
setattr(self, "get_client_config", get_client_config)
setattr(self.event_validator, "get_client_config", get_client_config)
self.broadcast_event = broadcast_event
self.get_client_config = get_client_config
self.event_validator.get_client_config = get_client_config
async def notify_event(self, event: NostrEvent) -> bool:
if self._is_direct_message_for_other(event):
return False
for filter in self.filters:
if filter.matches(event):
resp = event.serialize_response(filter.subscription_id)
for nostr_filter in self.filters:
if nostr_filter.matches(event):
resp = event.serialize_response(nostr_filter.subscription_id)
await self._send_msg(resp)
return True
return False
@ -82,7 +81,8 @@ class NostrClientConnection:
def _is_direct_message_for_other(self, event: NostrEvent) -> bool:
"""
Direct messages are not inteded to be boradcast (even if encrypted).
If the server requires AUTH for kind '4' then direct message will be sent only to the intended client.
If the server requires AUTH for kind '4' then direct message will be
sent only to the intended client.
"""
if not event.is_direct_message:
return False
@ -136,7 +136,7 @@ class NostrClientConnection:
await self._send_msg(["AUTH", self._current_auth_challenge()])
resp_nip20 += [
False,
f"restricted: Relay requires authentication for events of kind '{e.kind}'",
f"Relay requires authentication for events of kind '{e.kind}'",
]
await self._send_msg(resp_nip20)
return None
@ -166,7 +166,7 @@ class NostrClientConnection:
event = await get_event(self.relay_id, e.id)
# todo: handle NIP20 in detail
message = "error: failed to create event"
resp_nip20 += [event != None, message]
resp_nip20 += [event is not None, message]
await self._send_msg(resp_nip20)
@ -181,13 +181,15 @@ class NostrClientConnection:
async def _handle_delete_event(self, event: NostrEvent):
# NIP 09
filter = NostrFilter(authors=[event.pubkey])
filter.ids = [t[1] for t in event.tags if t[0] == "e"]
events_to_delete = await get_events(self.relay_id, filter, False)
nostr_filter = NostrFilter(authors=[event.pubkey])
nostr_filter.ids = [t[1] for t in event.tags if t[0] == "e"]
events_to_delete = await get_events(self.relay_id, nostr_filter, False)
ids = [e.id for e in events_to_delete if not e.is_delete_event]
await mark_events_deleted(self.relay_id, NostrFilter(ids=ids))
async def _handle_request(self, subscription_id: str, filter: NostrFilter) -> List:
async def _handle_request(
self, subscription_id: str, nostr_filter: NostrFilter
) -> List:
if self.config.require_auth_filter:
if not self.auth_pubkey:
return [["AUTH", self._current_auth_challenge()]]
@ -199,26 +201,30 @@ class NostrClientConnection:
return [
[
"NOTICE",
f"Public key '{self.auth_pubkey}' is not allowed in relay '{self.relay_id}'!",
(
f"Public key '{self.auth_pubkey}' is not allowed "
f"in relay '{self.relay_id}'!"
),
]
]
if not account.can_join and not self.config.is_free_to_join:
return [["NOTICE", f"This is a paid relay: '{self.relay_id}'"]]
filter.subscription_id = subscription_id
nostr_filter.subscription_id = subscription_id
self._remove_filter(subscription_id)
if self._can_add_filter():
max_filters = self.config.max_client_filters
return [
[
"NOTICE",
f"Maximum number of filters ({self.config.max_client_filters}) exceeded.",
f"Maximum number of filters ({max_filters}) exceeded.",
]
]
filter.enforce_limit(self.config.limit_per_filter)
self.filters.append(filter)
events = await get_events(self.relay_id, filter)
nostr_filter.enforce_limit(self.config.limit_per_filter)
self.filters.append(nostr_filter)
events = await get_events(self.relay_id, nostr_filter)
events = [e for e in events if not self._is_direct_message_for_other(e)]
serialized_events = [
event.serialize_response(subscription_id) for event in events

View file

@ -71,5 +71,5 @@ class NostrClientManager:
def get_client_config() -> RelaySpec:
return self.get_relay_config(client.relay_id)
setattr(client, "get_client_config", get_client_config)
client.get_client_config = get_client_config
client.init_callbacks(self.broadcast_event, get_client_config)

View file

@ -34,8 +34,7 @@ class NostrEvent(BaseModel):
@property
def event_id(self) -> str:
data = self.serialize_json()
id = hashlib.sha256(data.encode()).hexdigest()
return id
return hashlib.sha256(data.encode()).hexdigest()
@property
def size_bytes(self) -> int:
@ -74,10 +73,10 @@ class NostrEvent(BaseModel):
)
try:
pub_key = PublicKey(bytes.fromhex("02" + self.pubkey), True)
except Exception:
except Exception as exc:
raise ValueError(
f"Invalid public key: '{self.pubkey}' for event '{self.id}'"
)
) from exc
valid_signature = pub_key.schnorr_verify(
bytes.fromhex(event_id), bytes.fromhex(self.sig), None, raw=True

View file

@ -61,7 +61,7 @@ class EventValidator:
def _validate_event(self, e: NostrEvent) -> Tuple[bool, str]:
if self._exceeded_max_events_per_hour():
return False, f"Exceeded max events per hour limit'!"
return False, "Exceeded max events per hour limit'!"
try:
e.check_signature()
@ -101,7 +101,7 @@ class EventValidator:
if self.config.full_storage_action == "block":
return (
False,
f"Cannot write event, no more storage available for public key: '{pubkey}'",
f"Cannot write event, no storage available for public key: '{pubkey}'",
)
if event_size_bytes > total_available_storage:

View file

@ -6,16 +6,15 @@ from .event import NostrEvent
class NostrFilter(BaseModel):
subscription_id: Optional[str]
e: List[str] = Field(default=[], alias="#e")
p: List[str] = Field(default=[], alias="#p")
ids: List[str] = []
authors: List[str] = []
kinds: List[int] = []
e: List[str] = Field([], alias="#e")
p: List[str] = Field([], alias="#p")
since: Optional[int]
until: Optional[int]
limit: Optional[int]
subscription_id: Optional[str] = None
since: Optional[int] = None
until: Optional[int] = None
limit: Optional[int] = None
def matches(self, e: NostrEvent) -> bool:
# todo: starts with
@ -78,7 +77,8 @@ class NostrFilter(BaseModel):
values += self.e
e_s = ",".join(["?"] * len(self.e))
inner_joins.append(
"INNER JOIN nostrrelay.event_tags e_tags ON nostrrelay.events.id = e_tags.event_id"
"INNER JOIN nostrrelay.event_tags e_tags "
"ON nostrrelay.events.id = e_tags.event_id"
)
where.append(f" (e_tags.value in ({e_s}) AND e_tags.name = 'e')")
@ -86,7 +86,8 @@ class NostrFilter(BaseModel):
values += self.p
p_s = ",".join(["?"] * len(self.p))
inner_joins.append(
"INNER JOIN nostrrelay.event_tags p_tags ON nostrrelay.events.id = p_tags.event_id"
"INNER JOIN nostrrelay.event_tags p_tags "
"ON nostrrelay.events.id = p_tags.event_id"
)
where.append(f" p_tags.value in ({p_s}) AND p_tags.name = 'p'")

View file

@ -11,22 +11,22 @@ class Spec(BaseModel):
class FilterSpec(Spec):
max_client_filters = Field(0, alias="maxClientFilters")
limit_per_filter = Field(1000, alias="limitPerFilter")
max_client_filters: int = Field(default=0, alias="maxClientFilters")
limit_per_filter: int = Field(default=1000, alias="limitPerFilter")
class EventSpec(Spec):
max_events_per_hour = Field(0, alias="maxEventsPerHour")
max_events_per_hour: int = Field(default=0, alias="maxEventsPerHour")
created_at_days_past = Field(0, alias="createdAtDaysPast")
created_at_hours_past = Field(0, alias="createdAtHoursPast")
created_at_minutes_past = Field(0, alias="createdAtMinutesPast")
created_at_seconds_past = Field(0, alias="createdAtSecondsPast")
created_at_days_past: int = Field(default=0, alias="createdAtDaysPast")
created_at_hours_past: int = Field(default=0, alias="createdAtHoursPast")
created_at_minutes_past: int = Field(default=0, alias="createdAtMinutesPast")
created_at_seconds_past: int = Field(default=0, alias="createdAtSecondsPast")
created_at_days_future = Field(0, alias="createdAtDaysFuture")
created_at_hours_future = Field(0, alias="createdAtHoursFuture")
created_at_minutes_future = Field(0, alias="createdAtMinutesFuture")
created_at_seconds_future = Field(0, alias="createdAtSecondsFuture")
created_at_days_future: int = Field(default=0, alias="createdAtDaysFuture")
created_at_hours_future: int = Field(default=0, alias="createdAtHoursFuture")
created_at_minutes_future: int = Field(default=0, alias="createdAtMinutesFuture")
created_at_seconds_future: int = Field(default=0, alias="createdAtSecondsFuture")
@property
def created_at_in_past(self) -> int:
@ -48,9 +48,9 @@ class EventSpec(Spec):
class StorageSpec(Spec):
free_storage_value = Field(1, alias="freeStorageValue")
free_storage_unit = Field("MB", alias="freeStorageUnit")
full_storage_action = Field("prune", alias="fullStorageAction")
free_storage_value: int = Field(default=1, alias="freeStorageValue")
free_storage_unit: str = Field(default="MB", alias="freeStorageUnit")
full_storage_action: str = Field(default="prune", alias="fullStorageAction")
@property
def free_storage_bytes_value(self):
@ -61,10 +61,10 @@ class StorageSpec(Spec):
class AuthSpec(Spec):
require_auth_events = Field(False, alias="requireAuthEvents")
skiped_auth_events = Field([], alias="skipedAuthEvents")
forced_auth_events = Field([], alias="forcedAuthEvents")
require_auth_filter = Field(False, alias="requireAuthFilter")
require_auth_events: bool = Field(default=False, alias="requireAuthEvents")
skiped_auth_events: list = Field(default=[], alias="skipedAuthEvents")
forced_auth_events: list = Field(default=[], alias="forcedAuthEvents")
require_auth_filter: bool = Field(default=False, alias="requireAuthFilter")
def event_requires_auth(self, kind: int) -> bool:
if self.require_auth_events:
@ -73,11 +73,11 @@ class AuthSpec(Spec):
class PaymentSpec(Spec):
is_paid_relay = Field(False, alias="isPaidRelay")
cost_to_join = Field(0, alias="costToJoin")
is_paid_relay: bool = Field(default=False, alias="isPaidRelay")
cost_to_join: int = Field(default=0, alias="costToJoin")
storage_cost_value = Field(0, alias="storageCostValue")
storage_cost_unit = Field("MB", alias="storageCostUnit")
storage_cost_value: int = Field(default=0, alias="storageCostValue")
storage_cost_unit: str = Field(default="MB", alias="storageCostUnit")
@property
def is_free_to_join(self):
@ -85,7 +85,7 @@ class PaymentSpec(Spec):
class WalletSpec(Spec):
wallet = Field("")
wallet: str = Field(default="")
class RelayPublicSpec(FilterSpec, EventSpec, StorageSpec, PaymentSpec):
@ -93,7 +93,7 @@ class RelayPublicSpec(FilterSpec, EventSpec, StorageSpec, PaymentSpec):
@property
def is_read_only_relay(self):
self.free_storage_value == 0 and not self.is_paid_relay
return self.free_storage_value == 0 and not self.is_paid_relay
class RelaySpec(RelayPublicSpec, WalletSpec, AuthSpec):
@ -108,7 +108,7 @@ class NostrRelay(BaseModel):
contact: Optional[str]
active: bool = False
config: "RelaySpec" = RelaySpec()
config = RelaySpec()
@property
def is_free_to_join(self):