refactor: query builder

This commit is contained in:
Vlad Stan 2023-02-06 10:19:47 +02:00
parent 282e65954b
commit 2dfd70c38c
2 changed files with 14 additions and 10 deletions

View file

@ -4,7 +4,7 @@ from typing import Any, Callable, List
from fastapi import WebSocket
from loguru import logger
from .crud import create_event, delete_events, get_event, get_events
from .crud import create_event, mark_events_deleted, get_event, get_events
from .models import NostrEvent, NostrEventType, NostrFilter
@ -96,7 +96,7 @@ class NostrClientConnection:
filter.ids = [t[1] for t in event.tags if t[0] == "e"]
events_to_delete = await get_events("111", filter, False)
ids = [e.id for e in events_to_delete if not e.is_delete_event()]
await delete_events("111", ids)
await mark_events_deleted("111", ids)
async def __handle_request(self, subscription_id: str, filter: NostrFilter) -> List:
filter.subscription_id = subscription_id

20
crud.py
View file

@ -53,7 +53,7 @@ async def get_event(relay_id: str, id: str) -> Optional[NostrEvent]:
event.tags = await get_event_tags(relay_id, id)
return event
async def delete_events(relay_id: str, id_list: List[str] = []):
async def mark_events_deleted(relay_id: str, id_list: List[str] = []):
if len(id_list) == 0:
return None
ids = ",".join(["?"] * len(id_list))
@ -99,22 +99,26 @@ async def get_event_tags(
def build_select_events_query(relay_id:str, filter:NostrFilter):
values, where_clause = build_where_clause(relay_id, filter)
inner_joins, where, values = build_where_clause(relay_id, filter)
query = f"""
SELECT id, pubkey, created_at, kind, content, sig
FROM nostrrelay.events {where_clause}
FROM nostrrelay.events
{" ".join(inner_joins)}
WHERE { " AND ".join(where)}
ORDER BY created_at DESC
"""
if filter.limit and type(filter.limit) == int and filter.limit > 0:
if filter.limit and filter.limit > 0:
query += f" LIMIT {filter.limit}"
return values, query
def build_where_clause(relay_id:str, filter:NostrFilter):
values: List[Any] = [relay_id]
inner_joins = []
where = ["deleted=false", "nostrrelay.events.relay_id = ?"]
values: List[Any] = [relay_id]
if len(filter.e):
values += filter.e
e_s = ",".join(["?"] * len(filter.e))
@ -145,10 +149,10 @@ def build_where_clause(relay_id:str, filter:NostrFilter):
if filter.since:
where.append("reated_at >= ?")
values += [filter.since]
if filter.until:
where.append("created_at <= ?")
values += [filter.until]
query = " ".join(inner_joins)+ " WHERE " + " AND ".join(where)
return values, query
return inner_joins, where, values