feat: allow custom relay IDs

This commit is contained in:
Vlad Stan 2023-02-06 16:58:29 +02:00
parent eedaa52bcf
commit aff949fed5
3 changed files with 41 additions and 27 deletions

View file

@ -34,8 +34,9 @@ class NostrClientManager:
class NostrClientConnection:
broadcast_event: Callable
def __init__(self, websocket: WebSocket):
def __init__(self, relay_id: str, websocket: WebSocket):
self.websocket = websocket
self.relay_id = relay_id
self.filters: List[NostrFilter] = []
async def start(self):
@ -83,9 +84,9 @@ class NostrClientConnection:
e.check_signature()
if e.is_replaceable_event():
await delete_events(
"111", NostrFilter(kinds=[e.kind], authors=[e.pubkey])
self.relay_id, NostrFilter(kinds=[e.kind], authors=[e.pubkey])
)
await create_event("111", e)
await create_event(self.relay_id, e)
await self.broadcast_event(self, e)
if e.is_delete_event():
await self.__handle_delete_event(e)
@ -93,7 +94,7 @@ class NostrClientConnection:
except ValueError:
resp_nip20 += [False, "invalid: wrong event `id` or `sig`"]
except Exception:
event = await get_event("111", e.id)
event = await get_event(self.relay_id, e.id)
# todo: handle NIP20 in detail
resp_nip20 += [event != None, f"error: failed to create event"]
@ -103,15 +104,15 @@ class NostrClientConnection:
# NIP 09
filter = NostrFilter(authors=[event.pubkey])
filter.ids = [t[1] for t in event.tags if t[0] == "e"]
events_to_delete = await get_events("111", filter, False)
events_to_delete = await get_events(self.relay_id, filter, False)
ids = [e.id for e in events_to_delete if not e.is_delete_event()]
await mark_events_deleted("111", NostrFilter(ids=ids))
await mark_events_deleted(self.relay_id, NostrFilter(ids=ids))
async def __handle_request(self, subscription_id: str, filter: NostrFilter) -> List:
filter.subscription_id = subscription_id
self.remove_filter(subscription_id)
self.filters.append(filter)
events = await get_events("111", filter)
events = await get_events(self.relay_id, filter)
serialized_events = [
event.serialize_response(subscription_id) for event in events
]

22
crud.py
View file

@ -46,18 +46,18 @@ async def get_relays(user_id: str) -> List[NostrRelay]:
async def get_public_relay(relay_id: str) -> Optional[dict]:
row = await db.fetchone("""SELECT * FROM nostrrelay.relays WHERE id = ?""", (relay_id,))
if row:
relay = NostrRelay.parse_obj({"id": row["id"], **json.loads(row["meta"])})
if not row:
return None
return {
"id": relay.id,
"name": relay.name,
"description":relay.description,
"pubkey":relay.pubkey,
"contact":relay.contact,
"supported_nips":relay.supported_nips,
}
return None
relay = NostrRelay.from_row(row)
return {
"id": relay.id,
"name": relay.name,
"description":relay.description,
"pubkey":relay.pubkey,
"contact":relay.contact,
"supported_nips":relay.supported_nips,
}
async def delete_relay(user_id: str, relay_id: str):

View file

@ -1,10 +1,11 @@
from http import HTTPStatus
from typing import List, Optional
from pydantic.types import UUID4
from fastapi import Depends, Query, WebSocket
from fastapi.exceptions import HTTPException
from fastapi.responses import JSONResponse
from loguru import logger
from pydantic.types import UUID4
from lnbits.decorators import (
WalletTypeInfo,
@ -16,15 +17,22 @@ from lnbits.helpers import urlsafe_short_hash
from . import nostrrelay_ext
from .client_manager import NostrClientConnection, NostrClientManager
from .crud import create_relay, delete_relay, get_relay, get_relays, update_relay
from .crud import (
create_relay,
delete_relay,
get_public_relay,
get_relay,
get_relays,
update_relay,
)
from .models import NostrRelay
client_manager = NostrClientManager()
@nostrrelay_ext.websocket("/client")
async def websocket_endpoint(websocket: WebSocket):
client = NostrClientConnection(websocket=websocket)
@nostrrelay_ext.websocket("/{relay_id}")
async def websocket_endpoint(relay_id: str, websocket: WebSocket):
client = NostrClientConnection(relay_id=relay_id, websocket=websocket)
client_manager.add_client(client)
try:
await client.start()
@ -33,15 +41,20 @@ async def websocket_endpoint(websocket: WebSocket):
client_manager.remove_client(client)
@nostrrelay_ext.get("/client", status_code=HTTPStatus.OK)
async def api_nostrrelay_info():
@nostrrelay_ext.get("/{relay_id}", status_code=HTTPStatus.OK)
async def api_nostrrelay_info(relay_id: str):
headers = {
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Headers": "*",
"Access-Control-Allow-Methods": "GET"
}
info = NostrRelay()
return JSONResponse(content=dict(info), headers=headers)
relay = await get_public_relay(relay_id)
if not relay:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="Relay not found",
)
return JSONResponse(content=relay, headers=headers)