diff --git a/client_manager.py b/client_manager.py index 45c63fe..a717a48 100644 --- a/client_manager.py +++ b/client_manager.py @@ -4,7 +4,7 @@ from typing import Any, Callable, List from fastapi import WebSocket from loguru import logger -from .crud import create_event, mark_events_deleted, get_event, get_events +from .crud import create_event, get_event, get_events, mark_events_deleted from .models import NostrEvent, NostrEventType, NostrFilter diff --git a/crud.py b/crud.py index 6139fd9..2711aac 100644 --- a/crud.py +++ b/crud.py @@ -54,11 +54,11 @@ async def get_event(relay_id: str, id: str) -> Optional[NostrEvent]: return event async def mark_events_deleted(relay_id: str, filter: NostrFilter): - if len(filter.ids) == 0: + if filter.is_empty(): return None - ids = ",".join(["?"] * len(filter.ids)) - values = [relay_id] + filter.ids - await db.execute(f"UPDATE nostrrelay.events SET deleted=true WHERE relay_id = ? AND id IN ({ids})", tuple(values)) + _, where, values = build_where_clause(relay_id, filter) + + await db.execute(f"""UPDATE nostrrelay.events SET deleted=true WHERE {" AND ".join(where)}""", tuple(values)) async def create_event_tags( @@ -109,6 +109,7 @@ def build_select_events_query(relay_id:str, filter:NostrFilter): ORDER BY created_at DESC """ + # todo: check range if filter.limit and filter.limit > 0: query += f" LIMIT {filter.limit}" diff --git a/models.py b/models.py index b064d6d..9c16c03 100644 --- a/models.py +++ b/models.py @@ -20,14 +20,15 @@ class NostrRelay(BaseModel): def from_row(cls, row: Row) -> "NostrRelay": return cls(**dict(row)) + class NostrRelayInfo(BaseModel): - name: Optional[str] - description: Optional[str] - pubkey: Optional[str] - contact: Optional[str] = "https://t.me/lnbits" - supported_nips: List[str] = ["NIP01", "NIP09", "NIP11", "NIP15", "NIP20"] - software: Optional[str] = "LNbist" - version: Optional[str] + name: Optional[str] + description: Optional[str] + pubkey: Optional[str] + contact: Optional[str] = "https://t.me/lnbits" + supported_nips: List[str] = ["NIP01", "NIP09", "NIP11", "NIP15", "NIP20"] + software: Optional[str] = "LNbist" + version: Optional[str] class NostrEventType(str, Enum): @@ -80,7 +81,6 @@ class NostrEvent(BaseModel): if not valid_signature: raise ValueError(f"Invalid signature: '{self.sig}' for event '{self.id}'") - def serialize_response(self, subscription_id): return [NostrEventType.EVENT, subscription_id, dict(self)] @@ -128,8 +128,21 @@ class NostrFilter(BaseModel): return True event_tag_values = [t[1] for t in event_tags if t[0] == tag_name] - - common_tags = [event_tag for event_tag in event_tag_values if event_tag in filter_tags] + + common_tags = [ + event_tag for event_tag in event_tag_values if event_tag in filter_tags + ] if len(common_tags) == 0: return False return True + + def is_empty(self): + return ( + len(self.ids) == 0 + and len(self.authors) == 0 + and len(self.kinds) == 0 + and len(self.e) == 0 + and len(self.p) == 0 + and (not self.since) + and (not self.until) + )