diff --git a/nostr/client/client.py b/nostr/client/client.py index 7d54d47..a574f20 100644 --- a/nostr/client/client.py +++ b/nostr/client/client.py @@ -134,5 +134,5 @@ class NostrClient: while self.relay_manager.message_pool.has_events(): event_msg = self.relay_manager.message_pool.get_event() if callback_func: - callback_func(event_msg.event) + callback_func(event_msg) time.sleep(0.1) diff --git a/tasks.py b/tasks.py index 0855bb3..fc6684f 100644 --- a/tasks.py +++ b/tasks.py @@ -4,6 +4,7 @@ import threading from .nostr.client.client import NostrClient from .nostr.event import Event +from .nostr.message_pool import EventMessage from .nostr.key import PublicKey from .nostr.relay_manager import RelayManager @@ -27,7 +28,8 @@ client = NostrClient( # privatekey_hex="211aac75a687ad96cca402406f8147a2726e31c5fc838e22ce30640ca1f3a6fe", # ) -received_event_queue: asyncio.Queue[Event] = asyncio.Queue(0) +received_event_queue: asyncio.Queue[EventMessage] = asyncio.Queue(0) +received_subscription_events: dict[str, list[Event]] = {} from .crud import get_relays @@ -75,9 +77,28 @@ async def subscribe_events(): while not any([r.connected for r in client.relay_manager.relays.values()]): await asyncio.sleep(2) - def callback(event: Event): + def callback(eventMessage: EventMessage): # print(f"From {event.public_key[:3]}..{event.public_key[-3:]}: {event.content}") - asyncio.run(received_event_queue.put(event)) + + if eventMessage.subscription_id in received_subscription_events: + # do not add duplicate events (by signature) + if eventMessage.event.signature in set( + [ + e.signature + for e in received_subscription_events[eventMessage.subscription_id] + ] + ): + return + + received_subscription_events[eventMessage.subscription_id].append( + eventMessage.event + ) + else: + received_subscription_events[eventMessage.subscription_id] = [ + eventMessage.event + ] + + asyncio.run(received_event_queue.put(eventMessage)) t = threading.Thread( target=client.subscribe, diff --git a/views_api.py b/views_api.py index d4d8601..2f0738f 100644 --- a/views_api.py +++ b/views_api.py @@ -2,8 +2,9 @@ from http import HTTPStatus import asyncio import ssl import json -from typing import List -from fastapi import Request, WebSocket +import datetime +from typing import List, Union +from fastapi import Request, WebSocket, WebSocketDisconnect from fastapi.param_functions import Query from fastapi.params import Depends from fastapi.responses import JSONResponse @@ -14,7 +15,7 @@ from loguru import logger from . import nostrclient_ext -from .tasks import client, received_event_queue +from .tasks import client, received_event_queue, received_subscription_events from .crud import get_relays, add_relay, delete_relay from .models import RelayList, Relay, Event, Filter, Filters @@ -80,76 +81,189 @@ async def api_delete_relay(relay: Relay): # type: ignore await delete_relay(relay) -@nostrclient_ext.post("/api/v1/publish") -async def api_post_event(event: Event): - nostr_event = NostrEvent( - content=event.content, - public_key=event.pubkey, - created_at=event.created_at, # type: ignore - kind=event.kind, - tags=event.tags or None, # type: ignore - signature=event.sig, - ) - client.relay_manager.publish_event(nostr_event) +# @nostrclient_ext.post("/api/v1/publish") +# async def api_post_event(event: Event): +# nostr_event = NostrEvent( +# content=event.content, +# public_key=event.pubkey, +# created_at=event.created_at, # type: ignore +# kind=event.kind, +# tags=event.tags or None, # type: ignore +# signature=event.sig, +# ) +# client.relay_manager.publish_event(nostr_event) -@nostrclient_ext.post("/api/v1/filters") -async def api_subscribe(filters: Filters): - nostr_filters = init_filters(filters.__root__) +# @nostrclient_ext.post("/api/v1/filters") +# async def api_subscribe(filters: Filters): +# nostr_filters = init_filters(filters.__root__) - return EventSourceResponse( - event_getter(nostr_filters), - ping=20, - media_type="text/event-stream", - ) +# return EventSourceResponse( +# event_getter(nostr_filters), +# ping=20, +# media_type="text/event-stream", +# ) -@nostrclient_ext.websocket("/api/v1/filters") -async def ws_filter_subscribe(websocket: WebSocket): - await websocket.accept() - while True: - json_data = await websocket.receive_text() - try: - data = json.loads(json_data) - filters = data if isinstance(data, list) else [data] - filters = [Filter.parse_obj(f) for f in filters] - nostr_filters = init_filters(filters) - async for message in event_getter(nostr_filters): - await websocket.send_text(message) - - except Exception as e: - logger.warning(e) - - -def init_filters(filters: List[Filter]): - filter_list = [] +def marshall_nostr_filters(data: Union[dict, list]): + filters = data if isinstance(data, list) else [data] + filters = [Filter.parse_obj(f) for f in filters] + filter_list: list[NostrFilter] = [] for filter in filters: filter_list.append( NostrFilter( - event_ids=filter.ids, + event_ids=filter.ids, # type: ignore kinds=filter.kinds, # type: ignore - authors=filter.authors, - since=filter.since, - until=filter.until, - event_refs=filter.e, - pubkey_refs=filter.p, - limit=filter.limit, + authors=filter.authors, # type: ignore + since=filter.since, # type: ignore + until=filter.until, # type: ignore + event_refs=filter.e, # type: ignore + pubkey_refs=filter.p, # type: ignore + limit=filter.limit, # type: ignore ) ) - - nostr_filters = NostrFilters(filter_list) - subscription_id = urlsafe_short_hash() - client.relay_manager.add_subscription(subscription_id, nostr_filters) - - request = [ClientMessageType.REQUEST, subscription_id] - request.extend(nostr_filters.to_json_array()) - message = json.dumps(request) - client.relay_manager.publish_message(message) - return nostr_filters + return NostrFilters(filter_list) -async def event_getter(nostr_filters): +async def add_nostr_subscription(json_str): + """Parses a (string) request from a client. If it is a subscription (REQ), it will + register the subscription in the nostr client library that we're using so we can + receive the callbacks on it later""" + json_data = json.loads(json_str) + assert len(json_data) + if json_data[0] == "REQ": + subscription_id = json_data[1] + fltr = json_data[2] + filters = marshall_nostr_filters(fltr) + client.relay_manager.add_subscription(subscription_id, filters) + return subscription_id + + +@nostrclient_ext.websocket("/api/v1/relay") +async def ws_relay(websocket: WebSocket): + """Relay multiplexer: one client (per endpoint) <-> multiple relays""" + await websocket.accept() + my_subscriptions: List[str] = [] + connected: bool = True + last_sent: datetime.datetime = datetime.datetime.now() + + async def client_to_nostr(websocket): + """Receives requests / data from the client and forwards it to relays. If the + request was a subscription/filter, registers it with the nostr client lib. + Remembers the subscription id so we can send back responses from the relay to this + client in `nostr_to_client`""" + nonlocal my_subscriptions + nonlocal last_sent + nonlocal connected + while True: + try: + json_str = await websocket.receive_text() + except WebSocketDisconnect: + connected = False + break + # print(json_str) + + # registers a subscription if the input was a REQ request + subscription_id = await add_nostr_subscription(json_str) + if subscription_id: + my_subscriptions.append(subscription_id) + + # publish data + client.relay_manager.publish_message(json_str) + + # update timestamp of last sent data + last_sent = datetime.datetime.now() + + async def nostr_to_client(websocket): + """Sends responses from relays back to the client. Polls the subscriptions of this client + stored in `my_subscriptions`. Then gets all responses for this subscription id from `received_subscription_events` which + is filled in tasks.py. Takes one response after the other and relays it back to the client. Reconstructs + the reponse manually because the nostr client lib we're using can't do it.""" + nonlocal connected + while True and connected: + for s in my_subscriptions: + if s in received_subscription_events: + while len(received_subscription_events[s]): + my_event = received_subscription_events[s].pop(0) + # event.to_message() does not include the subscription ID, we have to add it manually + event_json = { + "id": my_event.id, + "pubkey": my_event.public_key, + "created_at": my_event.created_at, + "kind": my_event.kind, + "tags": my_event.tags, + "content": my_event.content, + "sig": my_event.signature, + } + + # this reconstructs the original response from the relay + event_to_forward = ["EVENT", s, event_json] + # print(json.dumps(event_to_forward)) + + # send data back to client + await websocket.send_text(json.dumps(event_to_forward)) + await asyncio.sleep(0.1) + + asyncio.create_task(client_to_nostr(websocket)) + asyncio.create_task(nostr_to_client(websocket)) + + # we kill this websocket and the subscriptions if no data was sent for + # more than 10 minutes _or_ if the user disconnects and thus `connected==False` while True: - event = await received_event_queue.get() - if nostr_filters.match(event): - yield event.to_message() \ No newline at end of file + await asyncio.sleep(10) + if ( + datetime.datetime.now() - last_sent > datetime.timedelta(minutes=10) + or not connected + ): + break + + +# @nostrclient_ext.websocket("/api/v1/filters") +# async def ws_filter_subscribe(websocket: WebSocket): +# await websocket.accept() +# while True: +# json_data = await websocket.receive_text() +# try: +# data = json.loads(json_data) +# filters = data if isinstance(data, list) else [data] +# filters = [Filter.parse_obj(f) for f in filters] +# nostr_filters = init_filters(filters) +# async for message in event_getter(nostr_filters): +# await websocket.send_text(message) + +# except Exception as e: +# logger.warning(e) + + +# def init_filters(filters: List[Filter]): +# filter_list = [] +# for filter in filters: +# filter_list.append( +# NostrFilter( +# event_ids=filter.ids, # type: ignore +# kinds=filter.kinds, # type: ignore +# authors=filter.authors, # type: ignore +# since=filter.since, # type: ignore +# until=filter.until, # type: ignore +# event_refs=filter.e, # type: ignore +# pubkey_refs=filter.p, # type: ignore +# limit=filter.limit, # type: ignore +# ) +# ) + +# nostr_filters = NostrFilters(filter_list) +# subscription_id = urlsafe_short_hash() +# client.relay_manager.add_subscription(subscription_id, nostr_filters) + +# request = [ClientMessageType.REQUEST, subscription_id] +# request.extend(nostr_filters.to_json_array()) +# message = json.dumps(request) +# client.relay_manager.publish_message(message) +# return nostr_filters + + +# async def event_getter(nostr_filters): +# while True: +# event_message = await received_event_queue.get() +# if nostr_filters.match(event_message.event): +# yield event_message.event.to_message()