diff --git a/__init__.py b/__init__.py index 019df68..ec6f4b4 100644 --- a/__init__.py +++ b/__init__.py @@ -8,6 +8,8 @@ from lnbits.db import Database from lnbits.helpers import template_renderer from lnbits.tasks import catch_everything_and_restart +from .nostr.client.client import NostrClient as NostrClientLib + db = Database("ext_nostrclient") nostrclient_static_files = [ @@ -22,12 +24,19 @@ nostrclient_ext: APIRouter = APIRouter(prefix="/nostrclient", tags=["nostrclient scheduled_tasks: List[asyncio.Task] = [] +class NostrClient: + def __init__(self): + self.client: NostrClientLib = NostrClientLib(connect=False) + + +nostr = NostrClient() + def nostr_renderer(): return template_renderer(["lnbits/extensions/nostrclient/templates"]) -from .tasks import init_relays, subscribe_events +from .tasks import check_relays, init_relays, subscribe_events from .views import * # noqa from .views_api import * # noqa @@ -38,3 +47,5 @@ def nostrclient_start(): scheduled_tasks.append(task1) task2 = loop.create_task(catch_everything_and_restart(subscribe_events)) scheduled_tasks.append(task2) + task3 = loop.create_task(catch_everything_and_restart(check_relays)) + scheduled_tasks.append(task3) diff --git a/models.py b/models.py index 1456d83..88651fc 100644 --- a/models.py +++ b/models.py @@ -8,12 +8,19 @@ from pydantic import BaseModel, Field from lnbits.helpers import urlsafe_short_hash +class RelayStatus(BaseModel): + num_sent_events: Optional[int] = 0 + num_received_events: Optional[int] = 0 + error_counter: Optional[int] = 0 + error_list: Optional[List] = [] + notice_list: Optional[List] = [] + class Relay(BaseModel): id: Optional[str] = None url: Optional[str] = None connected: Optional[bool] = None connected_string: Optional[str] = None - status: Optional[str] = None + status: Optional[RelayStatus] = None active: Optional[bool] = None ping: Optional[int] = None @@ -59,43 +66,3 @@ class TestMessageResponse(BaseModel): private_key: str public_key: str event_json: str - -# class nostrKeys(BaseModel): -# pubkey: str -# privkey: str - -# class nostrNotes(BaseModel): -# id: str -# pubkey: str -# created_at: str -# kind: int -# tags: str -# content: str -# sig: str - -# class nostrCreateRelays(BaseModel): -# relay: str = Query(None) - -# class nostrCreateConnections(BaseModel): -# pubkey: str = Query(None) -# relayid: str = Query(None) - -# class nostrRelays(BaseModel): -# id: Optional[str] -# relay: Optional[str] -# status: Optional[bool] = False - - -# class nostrRelaySetList(BaseModel): -# allowlist: Optional[str] -# denylist: Optional[str] - -# class nostrConnections(BaseModel): -# id: str -# pubkey: Optional[str] -# relayid: Optional[str] - -# class nostrSubscriptions(BaseModel): -# id: str -# userPubkey: Optional[str] -# subscribedPubkey: Optional[str] diff --git a/nostr/bech32.py b/nostr/bech32.py index b068de7..61a92c4 100644 --- a/nostr/bech32.py +++ b/nostr/bech32.py @@ -23,6 +23,7 @@ from enum import Enum + class Encoding(Enum): """Enumeration type to list the various supported encodings.""" BECH32 = 1 diff --git a/nostr/client/cbc.py b/nostr/client/cbc.py deleted file mode 100644 index a41dbc0..0000000 --- a/nostr/client/cbc.py +++ /dev/null @@ -1,41 +0,0 @@ - -from Cryptodome import Random -from Cryptodome.Cipher import AES - -plain_text = "This is the text to encrypts" - -# encrypted = "7mH9jq3K9xNfWqIyu9gNpUz8qBvGwsrDJ+ACExdV1DvGgY8q39dkxVKeXD7LWCDrPnoD/ZFHJMRMis8v9lwHfNgJut8EVTMuJJi8oTgJevOBXl+E+bJPwej9hY3k20rgCQistNRtGHUzdWyOv7S1tg==".encode() -# iv = "GzDzqOVShWu3Pl2313FBpQ==".encode() - -key = bytes.fromhex("3aa925cb69eb613e2928f8a18279c78b1dca04541dfd064df2eda66b59880795") - -BLOCK_SIZE = 16 - -class AESCipher(object): - """This class is compatible with crypto.createCipheriv('aes-256-cbc') - - """ - def __init__(self, key=None): - self.key = key - - def pad(self, data): - length = BLOCK_SIZE - (len(data) % BLOCK_SIZE) - return data + (chr(length) * length).encode() - - def unpad(self, data): - return data[: -(data[-1] if type(data[-1]) == int else ord(data[-1]))] - - def encrypt(self, plain_text): - cipher = AES.new(self.key, AES.MODE_CBC) - b = plain_text.encode("UTF-8") - return cipher.iv, cipher.encrypt(self.pad(b)) - - def decrypt(self, iv, enc_text): - cipher = AES.new(self.key, AES.MODE_CBC, iv=iv) - return self.unpad(cipher.decrypt(enc_text).decode("UTF-8")) - -if __name__ == "__main__": - aes = AESCipher(key=key) - iv, enc_text = aes.encrypt(plain_text) - dec_text = aes.decrypt(iv, enc_text) - print(dec_text) \ No newline at end of file diff --git a/nostr/client/client.py b/nostr/client/client.py index 6e70f71..66d722c 100644 --- a/nostr/client/client.py +++ b/nostr/client/client.py @@ -1,38 +1,14 @@ -from typing import * -import ssl import time -import json -import os -import base64 +from typing import List -from ..event import Event from ..relay_manager import RelayManager -from ..message_type import ClientMessageType -from ..key import PrivateKey, PublicKey - -from ..filter import Filter, Filters -from ..event import Event, EventKind, EncryptedDirectMessage -from ..relay_manager import RelayManager -from ..message_type import ClientMessageType - -# from aes import AESCipher -from . import cbc class NostrClient: - relays = [ - "wss://nostr-pub.wellorder.net", - "wss://nostr.zebedee.cloud", - "wss://nodestr.fmt.wiz.biz", - "wss://nostr.oxtr.dev", - ] # ["wss://nostr.oxtr.dev"] # ["wss://relay.nostr.info"] "wss://nostr-pub.wellorder.net" "ws://91.237.88.218:2700", "wss://nostrrr.bublina.eu.org", ""wss://nostr-relay.freeberty.net"", , "wss://nostr.oxtr.dev", "wss://relay.nostr.info", "wss://nostr-pub.wellorder.net" , "wss://relayer.fiatjaf.com", "wss://nodestr.fmt.wiz.biz/", "wss://no.str.cr" + relays = [ ] relay_manager = RelayManager() - private_key: PrivateKey - public_key: PublicKey - - def __init__(self, privatekey_hex: str = "", relays: List[str] = [], connect=True): - self.generate_keys(privatekey_hex) + def __init__(self, relays: List[str] = [], connect=True): if len(relays): self.relays = relays if connect: @@ -41,94 +17,10 @@ class NostrClient: def connect(self): for relay in self.relays: self.relay_manager.add_relay(relay) - self.relay_manager.open_connections( - {"cert_reqs": ssl.CERT_NONE} - ) # NOTE: This disables ssl certificate verification def close(self): self.relay_manager.close_connections() - def generate_keys(self, privatekey_hex: str = None): - pk = bytes.fromhex(privatekey_hex) if privatekey_hex else None - self.private_key = PrivateKey(pk) - self.public_key = self.private_key.public_key - - def post(self, message: str): - event = Event(message, self.public_key.hex(), kind=EventKind.TEXT_NOTE) - self.private_key.sign_event(event) - event_json = event.to_message() - # print("Publishing message:") - # print(event_json) - self.relay_manager.publish_message(event_json) - - def get_post( - self, sender_publickey: PublicKey = None, callback_func=None, filter_kwargs={} - ): - filter = Filter( - authors=[sender_publickey.hex()] if sender_publickey else None, - kinds=[EventKind.TEXT_NOTE], - **filter_kwargs, - ) - filters = Filters([filter]) - subscription_id = os.urandom(4).hex() - self.relay_manager.add_subscription(subscription_id, filters) - - request = [ClientMessageType.REQUEST, subscription_id] - request.extend(filters.to_json_array()) - message = json.dumps(request) - self.relay_manager.publish_message(message) - - while True: - while self.relay_manager.message_pool.has_events(): - event_msg = self.relay_manager.message_pool.get_event() - if callback_func: - callback_func(event_msg.event) - time.sleep(0.1) - - def dm(self, message: str, to_pubkey: PublicKey): - dm = EncryptedDirectMessage( - recipient_pubkey=to_pubkey.hex(), cleartext_content=message - ) - self.private_key.sign_event(dm) - self.relay_manager.publish_event(dm) - - def get_dm(self, sender_publickey: PublicKey, callback_func=None): - filters = Filters( - [ - Filter( - kinds=[EventKind.ENCRYPTED_DIRECT_MESSAGE], - pubkey_refs=[sender_publickey.hex()], - ) - ] - ) - subscription_id = os.urandom(4).hex() - self.relay_manager.add_subscription(subscription_id, filters) - - request = [ClientMessageType.REQUEST, subscription_id] - request.extend(filters.to_json_array()) - message = json.dumps(request) - self.relay_manager.publish_message(message) - - while True: - while self.relay_manager.message_pool.has_events(): - event_msg = self.relay_manager.message_pool.get_event() - if "?iv=" in event_msg.event.content: - try: - shared_secret = self.private_key.compute_shared_secret( - event_msg.event.public_key - ) - aes = cbc.AESCipher(key=shared_secret) - enc_text_b64, iv_b64 = event_msg.event.content.split("?iv=") - iv = base64.decodebytes(iv_b64.encode("utf-8")) - enc_text = base64.decodebytes(enc_text_b64.encode("utf-8")) - dec_text = aes.decrypt(iv, enc_text) - if callback_func: - callback_func(event_msg.event, dec_text) - except: - pass - break - time.sleep(0.1) - def subscribe( self, callback_events_func=None, diff --git a/nostr/event.py b/nostr/event.py index b903e0e..65b187d 100644 --- a/nostr/event.py +++ b/nostr/event.py @@ -1,10 +1,11 @@ -import time import json +import time from dataclasses import dataclass, field from enum import IntEnum -from typing import List -from secp256k1 import PublicKey from hashlib import sha256 +from typing import List + +from secp256k1 import PublicKey from .message_type import ClientMessageType diff --git a/nostr/key.py b/nostr/key.py index d34697f..8089e11 100644 --- a/nostr/key.py +++ b/nostr/key.py @@ -1,14 +1,15 @@ -import secrets import base64 -import secp256k1 -from cffi import FFI -from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes -from cryptography.hazmat.primitives import padding +import secrets from hashlib import sha256 +import secp256k1 +from cffi import FFI +from cryptography.hazmat.primitives import padding +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes + +from . import bech32 from .delegation import Delegation from .event import EncryptedDirectMessage, Event, EventKind -from . import bech32 class PublicKey: diff --git a/nostr/message_pool.py b/nostr/message_pool.py index d364cf2..02f7fd4 100644 --- a/nostr/message_pool.py +++ b/nostr/message_pool.py @@ -1,8 +1,9 @@ import json from queue import Queue from threading import Lock -from .message_type import RelayMessageType + from .event import Event +from .message_type import RelayMessageType class EventMessage: @@ -68,10 +69,19 @@ class MessagePool: e["sig"], ) with self.lock: - if not event.id in self._unique_events: - self.events.put(EventMessage(event, subscription_id, url)) - self._unique_events.add(event.id) + if not f"{subscription_id}_{event.id}" in self._unique_events: + self._accept_event(EventMessage(event, subscription_id, url)) elif message_type == RelayMessageType.NOTICE: self.notices.put(NoticeMessage(message_json[1], url)) elif message_type == RelayMessageType.END_OF_STORED_EVENTS: self.eose_notices.put(EndOfStoredEventsMessage(message_json[1], url)) + + def _accept_event(self, event_message: EventMessage): + """ + Event uniqueness is considered per `subscription_id`. + The `subscription_id` is rewritten to be unique and it is the same accross relays. + The same event can come from different subscriptions (from the same client or from different ones). + Clients that have joined later should receive older events. + """ + self.events.put(event_message) + self._unique_events.add(f"{event_message.subscription_id}_{event_message.event.id}") \ No newline at end of file diff --git a/nostr/message_type.py b/nostr/message_type.py index 3f5206b..d37cdfd 100644 --- a/nostr/message_type.py +++ b/nostr/message_type.py @@ -3,13 +3,20 @@ class ClientMessageType: REQUEST = "REQ" CLOSE = "CLOSE" + class RelayMessageType: EVENT = "EVENT" NOTICE = "NOTICE" END_OF_STORED_EVENTS = "EOSE" + COMMAND_RESULT = "OK" @staticmethod def is_valid(type: str) -> bool: - if type == RelayMessageType.EVENT or type == RelayMessageType.NOTICE or type == RelayMessageType.END_OF_STORED_EVENTS: + if ( + type == RelayMessageType.EVENT + or type == RelayMessageType.NOTICE + or type == RelayMessageType.END_OF_STORED_EVENTS + or type == RelayMessageType.COMMAND_RESULT + ): return True - return False \ No newline at end of file + return False diff --git a/nostr/pow.py b/nostr/pow.py deleted file mode 100644 index e006288..0000000 --- a/nostr/pow.py +++ /dev/null @@ -1,54 +0,0 @@ -import time -from .event import Event -from .key import PrivateKey - -def zero_bits(b: int) -> int: - n = 0 - - if b == 0: - return 8 - - while b >> 1: - b = b >> 1 - n += 1 - - return 7 - n - -def count_leading_zero_bits(hex_str: str) -> int: - total = 0 - for i in range(0, len(hex_str) - 2, 2): - bits = zero_bits(int(hex_str[i:i+2], 16)) - total += bits - - if bits != 8: - break - - return total - -def mine_event(content: str, difficulty: int, public_key: str, kind: int, tags: list=[]) -> Event: - all_tags = [["nonce", "1", str(difficulty)]] - all_tags.extend(tags) - - created_at = int(time.time()) - event_id = Event.compute_id(public_key, created_at, kind, all_tags, content) - num_leading_zero_bits = count_leading_zero_bits(event_id) - - attempts = 1 - while num_leading_zero_bits < difficulty: - attempts += 1 - all_tags[0][1] = str(attempts) - created_at = int(time.time()) - event_id = Event.compute_id(public_key, created_at, kind, all_tags, content) - num_leading_zero_bits = count_leading_zero_bits(event_id) - - return Event(public_key, content, created_at, kind, all_tags, event_id) - -def mine_key(difficulty: int) -> PrivateKey: - sk = PrivateKey() - num_leading_zero_bits = count_leading_zero_bits(sk.public_key.hex()) - - while num_leading_zero_bits < difficulty: - sk = PrivateKey() - num_leading_zero_bits = count_leading_zero_bits(sk.public_key.hex()) - - return sk diff --git a/nostr/relay.py b/nostr/relay.py index 7fb4baa..8b081a3 100644 --- a/nostr/relay.py +++ b/nostr/relay.py @@ -2,7 +2,11 @@ import json import time from queue import Queue from threading import Lock +from typing import List + +from loguru import logger from websocket import WebSocketApp + from .event import Event from .filter import Filters from .message_pool import MessagePool @@ -36,6 +40,9 @@ class Relay: self.shutdown: bool = False self.error_counter: int = 0 self.error_threshold: int = 100 + self.error_list: List[str] = [] + self.notice_list: List[str] = [] + self.last_error_date: int = 0 self.num_received_events: int = 0 self.num_sent_events: int = 0 self.num_subscriptions: int = 0 @@ -67,17 +74,12 @@ class Relay: def close(self): self.ws.close() + self.connected = False self.shutdown = True - def check_reconnect(self): - try: - self.close() - except: - pass - self.connected = False - if self.reconnect: - time.sleep(self.error_counter**2) - self.connect(self.ssl_options, self.proxy) + @property + def error_threshold_reached(self): + return self.error_threshold and self.error_counter >= self.error_threshold @property def ping(self): @@ -87,15 +89,22 @@ class Relay: def publish(self, message: str): self.queue.put(message) - def queue_worker(self, shutdown): + def publish_subscriptions(self): + for _, subscription in self.subscriptions.items(): + s = subscription.to_json_object() + json_str = json.dumps(["REQ", s["id"], s["filters"][0]]) + self.publish(json_str) + + def queue_worker(self): while True: if self.connected: try: message = self.queue.get(timeout=1) self.num_sent_events += 1 self.ws.send(message) - except: - if shutdown(): + except Exception as e: + if self.shutdown: + logger.warning(f"Closing queue worker for '{self.url}'.") break else: time.sleep(0.1) @@ -107,11 +116,7 @@ class Relay: def close_subscription(self, id: str) -> None: with self.lock: self.subscriptions.pop(id) - - def update_subscription(self, id: str, filters: Filters) -> None: - with self.lock: - subscription = self.subscriptions[id] - subscription.filters = filters + self.publish(json.dumps(["CLOSE", id])) def to_json_object(self) -> dict: return { @@ -123,31 +128,32 @@ class Relay: ], } - def _on_open(self, class_obj): + def add_notice(self, notice: str): + self.notice_list = ([notice] + self.notice_list)[:20] + + def _on_open(self, _): + logger.info(f"Connected to relay: '{self.url}'.") self.connected = True - pass + + def _on_close(self, _, status_code, message): + logger.warning(f"Connection to relay {self.url} closed. Status: '{status_code}'. Message: '{message}'.") + self.close() - def _on_close(self, class_obj, status_code, message): - self.connected = False - if self.error_threshold and self.error_counter > self.error_threshold: - pass - else: - self.check_reconnect() - pass - - def _on_message(self, class_obj, message: str): + def _on_message(self, _, message: str): if self._is_valid_message(message): self.num_received_events += 1 self.message_pool.add_message(message, self.url) - def _on_error(self, class_obj, error): + def _on_error(self, _, error): + logger.warning(f"Relay error: '{str(error)}'") + self._append_error_message(str(error)) self.connected = False self.error_counter += 1 - def _on_ping(self, class_obj, message): + def _on_ping(self, *_): return - def _on_pong(self, class_obj, message): + def _on_pong(self, *_): return def _is_valid_message(self, message: str) -> bool: @@ -157,33 +163,58 @@ class Relay: message_json = json.loads(message) message_type = message_json[0] + if not RelayMessageType.is_valid(message_type): return False + if message_type == RelayMessageType.EVENT: - if not len(message_json) == 3: - return False - - subscription_id = message_json[1] - with self.lock: - if subscription_id not in self.subscriptions: - return False - - e = message_json[2] - event = Event( - e["content"], - e["pubkey"], - e["created_at"], - e["kind"], - e["tags"], - e["sig"], - ) - if not event.verify(): - return False - - with self.lock: - subscription = self.subscriptions[subscription_id] - - if subscription.filters and not subscription.filters.match(event): - return False + return self._is_valid_event_message(message_json) + + if message_type == RelayMessageType.COMMAND_RESULT: + return self._is_valid_command_result_message(message, message_json) return True + + def _is_valid_event_message(self, message_json): + if not len(message_json) == 3: + return False + + subscription_id = message_json[1] + with self.lock: + if subscription_id not in self.subscriptions: + return False + + e = message_json[2] + event = Event( + e["content"], + e["pubkey"], + e["created_at"], + e["kind"], + e["tags"], + e["sig"], + ) + if not event.verify(): + return False + + with self.lock: + subscription = self.subscriptions[subscription_id] + + if subscription.filters and not subscription.filters.match(event): + return False + + return True + + def _is_valid_command_result_message(self, message, message_json): + if not len(message_json) < 3: + return False + + if message_json[2] != True: + logger.warning(f"Relay '{self.url}' negative command result: '{message}'") + self._append_error_message(message) + return False + + return True + + def _append_error_message(self, message): + self.error_list = ([message] + self.error_list)[:20] + self.last_error_date = int(time.time()) \ No newline at end of file diff --git a/nostr/relay_manager.py b/nostr/relay_manager.py index a698a33..b2df735 100644 --- a/nostr/relay_manager.py +++ b/nostr/relay_manager.py @@ -1,11 +1,14 @@ -import json -import threading -from .event import Event +import ssl +import threading +import time + +from loguru import logger + from .filter import Filters -from .message_pool import MessagePool -from .message_type import ClientMessageType +from .message_pool import MessagePool, NoticeMessage from .relay import Relay, RelayPolicy +from .subscription import Subscription class RelayException(Exception): @@ -18,45 +21,55 @@ class RelayManager: self.threads: dict[str, threading.Thread] = {} self.queue_threads: dict[str, threading.Thread] = {} self.message_pool = MessagePool() + self._cached_subscriptions: dict[str, Subscription] = {} + self._subscriptions_lock = threading.Lock() + + def add_relay(self, url: str, read: bool = True, write: bool = True) -> Relay: + if url in list(self.relays.keys()): + return + + with self._subscriptions_lock: + subscriptions = self._cached_subscriptions.copy() - def add_relay( - self, url: str, read: bool = True, write: bool = True, subscriptions={} - ): policy = RelayPolicy(read, write) relay = Relay(url, policy, self.message_pool, subscriptions) self.relays[url] = relay + self._open_connection( + relay, + {"cert_reqs": ssl.CERT_NONE} + ) # NOTE: This disables ssl certificate verification + + relay.publish_subscriptions() + return relay + def remove_relay(self, url: str): - self.relays[url].close() - self.relays.pop(url) self.threads[url].join(timeout=1) self.threads.pop(url) + self.queue_threads[url].join(timeout=1) + self.queue_threads.pop(url) + self.relays[url].close() + self.relays.pop(url) def add_subscription(self, id: str, filters: Filters): + with self._subscriptions_lock: + self._cached_subscriptions[id] = Subscription(id, filters) + for relay in self.relays.values(): relay.add_subscription(id, filters) def close_subscription(self, id: str): + with self._subscriptions_lock: + self._cached_subscriptions.pop(id) + for relay in self.relays.values(): relay.close_subscription(id) - def open_connections(self, ssl_options: dict = None, proxy: dict = None): - for relay in self.relays.values(): - self.threads[relay.url] = threading.Thread( - target=relay.connect, - args=(ssl_options, proxy), - name=f"{relay.url}-thread", - daemon=True, - ) - self.threads[relay.url].start() + def check_and_restart_relays(self): + stopped_relays = [r for r in self.relays.values() if r.shutdown] + for relay in stopped_relays: + self._restart_relay(relay) - self.queue_threads[relay.url] = threading.Thread( - target=relay.queue_worker, - args=(lambda: relay.shutdown,), - name=f"{relay.url}-queue", - daemon=True, - ) - self.queue_threads[relay.url].start() def close_connections(self): for relay in self.relays.values(): @@ -67,13 +80,38 @@ class RelayManager: if relay.policy.should_write: relay.publish(message) - def publish_event(self, event: Event): - """Verifies that the Event is publishable before submitting it to relays""" - if event.signature is None: - raise RelayException(f"Could not publish {event.id}: must be signed") + def handle_notice(self, notice: NoticeMessage): + relay = next((r for r in self.relays.values() if r.url == notice.url)) + if relay: + relay.add_notice(notice.content) - if not event.verify(): - raise RelayException( - f"Could not publish {event.id}: failed to verify signature {event.signature}" - ) - self.publish_message(event.to_message()) + def _open_connection(self, relay: Relay, ssl_options: dict = None, proxy: dict = None): + self.threads[relay.url] = threading.Thread( + target=relay.connect, + args=(ssl_options, proxy), + name=f"{relay.url}-thread", + daemon=True, + ) + self.threads[relay.url].start() + + self.queue_threads[relay.url] = threading.Thread( + target=relay.queue_worker, + name=f"{relay.url}-queue", + daemon=True, + ) + self.queue_threads[relay.url].start() + + def _restart_relay(self, relay: Relay): + if relay.error_threshold_reached: + time_since_last_error = time.time() - relay.last_error_date + if time_since_last_error < 60 * 60 * 2: # last day + return + relay.error_counter = 0 + relay.error_list = [] + + logger.info(f"Restarting connection to relay '{relay.url}'") + + self.remove_relay(relay.url) + new_relay = self.add_relay(relay.url) + new_relay.error_counter = relay.error_counter + new_relay.error_list = relay.error_list \ No newline at end of file diff --git a/nostr/subscription.py b/nostr/subscription.py index 7afba20..76da0af 100644 --- a/nostr/subscription.py +++ b/nostr/subscription.py @@ -1,5 +1,6 @@ from .filter import Filters + class Subscription: def __init__(self, id: str, filters: Filters=None) -> None: self.id = id diff --git a/router.py b/router.py new file mode 100644 index 0000000..e85653c --- /dev/null +++ b/router.py @@ -0,0 +1,190 @@ +import asyncio +import json +from typing import List, Union + +from fastapi import WebSocketDisconnect +from loguru import logger + +from lnbits.helpers import urlsafe_short_hash + +from . import nostr +from .models import Event, Filter +from .nostr.filter import Filter as NostrFilter +from .nostr.filter import Filters as NostrFilters +from .nostr.message_pool import EndOfStoredEventsMessage, NoticeMessage + + +class NostrRouter: + + received_subscription_events: dict[str, list[Event]] = {} + received_subscription_notices: list[NoticeMessage] = [] + received_subscription_eosenotices: dict[str, EndOfStoredEventsMessage] = {} + + def __init__(self, websocket): + self.subscriptions: List[str] = [] + self.connected: bool = True + self.websocket = websocket + self.tasks: List[asyncio.Task] = [] + self.original_subscription_ids = {} + + async def client_to_nostr(self): + """Receives requests / data from the client and forwards it to relays. If the + request was a subscription/filter, registers it with the nostr client lib. + Remembers the subscription id so we can send back responses from the relay to this + client in `nostr_to_client`""" + while True: + try: + json_str = await self.websocket.receive_text() + except WebSocketDisconnect: + self.connected = False + break + + try: + await self._handle_client_to_nostr(json_str) + except Exception as e: + logger.debug(f"Failed to handle client message: '{str(e)}'.") + + + async def nostr_to_client(self): + """Sends responses from relays back to the client. Polls the subscriptions of this client + stored in `my_subscriptions`. Then gets all responses for this subscription id from `received_subscription_events` which + is filled in tasks.py. Takes one response after the other and relays it back to the client. Reconstructs + the reponse manually because the nostr client lib we're using can't do it. Reconstructs the original subscription id + that we had previously rewritten in order to avoid collisions when multiple clients use the same id. + """ + while True and self.connected: + try: + await self._handle_subscriptions() + self._handle_notices() + except Exception as e: + logger.debug(f"Failed to handle response for client: '{str(e)}'.") + await asyncio.sleep(0.1) + + + async def start(self): + self.tasks.append(asyncio.create_task(self.client_to_nostr())) + self.tasks.append(asyncio.create_task(self.nostr_to_client())) + + async def stop(self): + for t in self.tasks: + try: + t.cancel() + except: + pass + + for s in self.subscriptions: + try: + nostr.client.relay_manager.close_subscription(s) + except: + pass + self.connected = False + + async def _handle_subscriptions(self): + for s in self.subscriptions: + if s in NostrRouter.received_subscription_events: + await self._handle_received_subscription_events(s) + if s in NostrRouter.received_subscription_eosenotices: + await self._handle_received_subscription_eosenotices(s) + + + + async def _handle_received_subscription_eosenotices(self, s): + s_original = self.original_subscription_ids[s] + event_to_forward = ["EOSE", s_original] + del NostrRouter.received_subscription_eosenotices[s] + + await self.websocket.send_text(json.dumps(event_to_forward)) + + async def _handle_received_subscription_events(self, s): + while len(NostrRouter.received_subscription_events[s]): + my_event = NostrRouter.received_subscription_events[s].pop(0) + # event.to_message() does not include the subscription ID, we have to add it manually + event_json = { + "id": my_event.id, + "pubkey": my_event.public_key, + "created_at": my_event.created_at, + "kind": my_event.kind, + "tags": my_event.tags, + "content": my_event.content, + "sig": my_event.signature, + } + + # this reconstructs the original response from the relay + # reconstruct original subscription id + s_original = self.original_subscription_ids[s] + event_to_forward = ["EVENT", s_original, event_json] + await self.websocket.send_text(json.dumps(event_to_forward)) + + def _handle_notices(self): + while len(NostrRouter.received_subscription_notices): + my_event = NostrRouter.received_subscription_notices.pop(0) + # note: we don't send it to the user because we don't know who should receive it + logger.info(f"Relay ('{my_event.url}') notice: '{my_event.content}']") + nostr.client.relay_manager.handle_notice(my_event) + + + + def _marshall_nostr_filters(self, data: Union[dict, list]): + filters = data if isinstance(data, list) else [data] + filters = [Filter.parse_obj(f) for f in filters] + filter_list: list[NostrFilter] = [] + for filter in filters: + filter_list.append( + NostrFilter( + event_ids=filter.ids, # type: ignore + kinds=filter.kinds, # type: ignore + authors=filter.authors, # type: ignore + since=filter.since, # type: ignore + until=filter.until, # type: ignore + event_refs=filter.e, # type: ignore + pubkey_refs=filter.p, # type: ignore + limit=filter.limit, # type: ignore + ) + ) + return NostrFilters(filter_list) + + async def _handle_client_to_nostr(self, json_str): + """Parses a (string) request from a client. If it is a subscription (REQ) or a CLOSE, it will + register the subscription in the nostr client library that we're using so we can + receive the callbacks on it later. Will rewrite the subscription id since we expect + multiple clients to use the router and want to avoid subscription id collisions + """ + + json_data = json.loads(json_str) + assert len(json_data) + + + if json_data[0] == "REQ": + self._handle_client_req(json_data) + return + + if json_data[0] == "CLOSE": + self._handle_client_close(json_data[1]) + return + + if json_data[0] == "EVENT": + nostr.client.relay_manager.publish_message(json_str) + return + + def _handle_client_req(self, json_data): + subscription_id = json_data[1] + subscription_id_rewritten = urlsafe_short_hash() + self.original_subscription_ids[subscription_id_rewritten] = subscription_id + fltr = json_data[2] + filters = self._marshall_nostr_filters(fltr) + + nostr.client.relay_manager.add_subscription( + subscription_id_rewritten, filters + ) + request_rewritten = json.dumps([json_data[0], subscription_id_rewritten, fltr]) + + self.subscriptions.append(subscription_id_rewritten) + nostr.client.relay_manager.publish_message(request_rewritten) + + def _handle_client_close(self, subscription_id): + subscription_id_rewritten = next((k for k, v in self.original_subscription_ids.items() if v == subscription_id), None) + if subscription_id_rewritten: + self.original_subscription_ids.pop(subscription_id_rewritten) + nostr.client.relay_manager.close_subscription(subscription_id_rewritten) + else: + logger.debug(f"Failed to unsubscribe from '{subscription_id}.'") diff --git a/services.py b/services.py deleted file mode 100644 index 82f6578..0000000 --- a/services.py +++ /dev/null @@ -1,163 +0,0 @@ -import asyncio -import json -from typing import List, Union - -from fastapi import WebSocket, WebSocketDisconnect -from loguru import logger - -from lnbits.helpers import urlsafe_short_hash - -from .models import Event, Filter, Filters, Relay, RelayList -from .nostr.client.client import NostrClient as NostrClientLib -from .nostr.event import Event as NostrEvent -from .nostr.filter import Filter as NostrFilter -from .nostr.filter import Filters as NostrFilters -from .nostr.message_pool import EndOfStoredEventsMessage, EventMessage, NoticeMessage - -received_subscription_events: dict[str, list[Event]] = {} -received_subscription_notices: list[NoticeMessage] = [] -received_subscription_eosenotices: dict[str, EndOfStoredEventsMessage] = {} - - -class NostrClient: - def __init__(self): - self.client: NostrClientLib = NostrClientLib(connect=False) - - -nostr = NostrClient() - - -class NostrRouter: - def __init__(self, websocket): - self.subscriptions: List[str] = [] - self.connected: bool = True - self.websocket = websocket - self.tasks: List[asyncio.Task] = [] - self.oridinal_subscription_ids = {} - - async def client_to_nostr(self): - """Receives requests / data from the client and forwards it to relays. If the - request was a subscription/filter, registers it with the nostr client lib. - Remembers the subscription id so we can send back responses from the relay to this - client in `nostr_to_client`""" - while True: - try: - json_str = await self.websocket.receive_text() - except WebSocketDisconnect: - self.connected = False - break - - # registers a subscription if the input was a REQ request - subscription_id, json_str_rewritten = await self._add_nostr_subscription( - json_str - ) - - if subscription_id and json_str_rewritten: - self.subscriptions.append(subscription_id) - json_str = json_str_rewritten - - # publish data - nostr.client.relay_manager.publish_message(json_str) - - async def nostr_to_client(self): - """Sends responses from relays back to the client. Polls the subscriptions of this client - stored in `my_subscriptions`. Then gets all responses for this subscription id from `received_subscription_events` which - is filled in tasks.py. Takes one response after the other and relays it back to the client. Reconstructs - the reponse manually because the nostr client lib we're using can't do it. Reconstructs the original subscription id - that we had previously rewritten in order to avoid collisions when multiple clients use the same id. - """ - while True and self.connected: - for s in self.subscriptions: - if s in received_subscription_events: - while len(received_subscription_events[s]): - my_event = received_subscription_events[s].pop(0) - # event.to_message() does not include the subscription ID, we have to add it manually - event_json = { - "id": my_event.id, - "pubkey": my_event.public_key, - "created_at": my_event.created_at, - "kind": my_event.kind, - "tags": my_event.tags, - "content": my_event.content, - "sig": my_event.signature, - } - - # this reconstructs the original response from the relay - # reconstruct original subscription id - s_original = self.oridinal_subscription_ids[s] - event_to_forward = ["EVENT", s_original, event_json] - - # print("Event to forward") - # print(json.dumps(event_to_forward)) - - # send data back to client - await self.websocket.send_text(json.dumps(event_to_forward)) - if s in received_subscription_eosenotices: - my_event = received_subscription_eosenotices[s] - s_original = self.oridinal_subscription_ids[s] - event_to_forward = ["EOSE", s_original] - del received_subscription_eosenotices[s] - # send data back to client - # print("Sending EOSE", event_to_forward) - await self.websocket.send_text(json.dumps(event_to_forward)) - - # if s in received_subscription_notices: - while len(received_subscription_notices): - my_event = received_subscription_notices.pop(0) - event_to_forward = ["NOTICE", my_event.content] - # send data back to client - logger.debug("Nostrclient: Received notice", event_to_forward[1]) - # note: we don't send it to the user because we don't know who should receive it - # await self.websocket.send_text(json.dumps(event_to_forward)) - await asyncio.sleep(0.1) - - async def start(self): - self.tasks.append(asyncio.create_task(self.client_to_nostr())) - self.tasks.append(asyncio.create_task(self.nostr_to_client())) - - async def stop(self): - for t in self.tasks: - t.cancel() - self.connected = False - - def _marshall_nostr_filters(self, data: Union[dict, list]): - filters = data if isinstance(data, list) else [data] - filters = [Filter.parse_obj(f) for f in filters] - filter_list: list[NostrFilter] = [] - for filter in filters: - filter_list.append( - NostrFilter( - event_ids=filter.ids, # type: ignore - kinds=filter.kinds, # type: ignore - authors=filter.authors, # type: ignore - since=filter.since, # type: ignore - until=filter.until, # type: ignore - event_refs=filter.e, # type: ignore - pubkey_refs=filter.p, # type: ignore - limit=filter.limit, # type: ignore - ) - ) - return NostrFilters(filter_list) - - async def _add_nostr_subscription(self, json_str): - """Parses a (string) request from a client. If it is a subscription (REQ) or a CLOSE, it will - register the subscription in the nostr client library that we're using so we can - receive the callbacks on it later. Will rewrite the subscription id since we expect - multiple clients to use the router and want to avoid subscription id collisions - """ - json_data = json.loads(json_str) - assert len(json_data) - if json_data[0] in ["REQ", "CLOSE"]: - subscription_id = json_data[1] - subscription_id_rewritten = urlsafe_short_hash() - self.oridinal_subscription_ids[subscription_id_rewritten] = subscription_id - fltr = json_data[2] - filters = self._marshall_nostr_filters(fltr) - nostr.client.relay_manager.add_subscription( - subscription_id_rewritten, filters - ) - request_rewritten = json.dumps( - [json_data[0], subscription_id_rewritten, fltr] - ) - return subscription_id_rewritten, request_rewritten - return None, None diff --git a/tasks.py b/tasks.py index beff9db..05057e7 100644 --- a/tasks.py +++ b/tasks.py @@ -1,82 +1,66 @@ import asyncio -import json -import ssl import threading +from loguru import logger + +from . import nostr from .crud import get_relays -from .nostr.event import Event -from .nostr.key import PublicKey from .nostr.message_pool import EndOfStoredEventsMessage, EventMessage, NoticeMessage -from .nostr.relay_manager import RelayManager -from .services import ( - nostr, - received_subscription_eosenotices, - received_subscription_events, - received_subscription_notices, -) +from .router import NostrRouter, nostr async def init_relays(): - # we save any subscriptions teporarily to re-add them after reinitializing the client - subscriptions = {} - for relay in nostr.client.relay_manager.relays.values(): - # relay.add_subscription(id, filters) - for subscription_id, filters in relay.subscriptions.items(): - subscriptions[subscription_id] = filters - # reinitialize the entire client nostr.__init__() # get relays from db relays = await get_relays() # set relays and connect to them nostr.client.relays = list(set([r.url for r in relays.__root__ if r.url])) - nostr.client.connect() + await nostr.client.connect() - await asyncio.sleep(2) - # re-add subscriptions - for subscription_id, subscription in subscriptions.items(): - nostr.client.relay_manager.add_subscription( - subscription_id, subscription.filters - ) - s = subscription.to_json_object() - json_str = json.dumps(["REQ", s["id"], s["filters"][0]]) - nostr.client.relay_manager.publish_message(json_str) - return +async def check_relays(): + """ Check relays that have been disconnected """ + while True: + try: + await asyncio.sleep(20) + nostr.client.relay_manager.check_and_restart_relays() + except Exception as e: + logger.warning(f"Cannot restart relays: '{str(e)}'.") + async def subscribe_events(): while not any([r.connected for r in nostr.client.relay_manager.relays.values()]): await asyncio.sleep(2) def callback_events(eventMessage: EventMessage): - # print(f"From {event.public_key[:3]}..{event.public_key[-3:]}: {event.content}") - if eventMessage.subscription_id in received_subscription_events: + if eventMessage.subscription_id in NostrRouter.received_subscription_events: # do not add duplicate events (by event id) if eventMessage.event.id in set( [ e.id - for e in received_subscription_events[eventMessage.subscription_id] + for e in NostrRouter.received_subscription_events[eventMessage.subscription_id] ] ): return - received_subscription_events[eventMessage.subscription_id].append( + NostrRouter.received_subscription_events[eventMessage.subscription_id].append( eventMessage.event ) else: - received_subscription_events[eventMessage.subscription_id] = [ + NostrRouter.received_subscription_events[eventMessage.subscription_id] = [ eventMessage.event ] return def callback_notices(noticeMessage: NoticeMessage): - if noticeMessage not in received_subscription_notices: - received_subscription_notices.append(noticeMessage) + if noticeMessage not in NostrRouter.received_subscription_notices: + NostrRouter.received_subscription_notices.append(noticeMessage) return def callback_eose_notices(eventMessage: EndOfStoredEventsMessage): - if eventMessage.subscription_id not in received_subscription_eosenotices: - received_subscription_eosenotices[ + if eventMessage.subscription_id not in NostrRouter.received_subscription_eosenotices: + NostrRouter.received_subscription_eosenotices[ eventMessage.subscription_id ] = eventMessage diff --git a/templates/nostrclient/index.html b/templates/nostrclient/index.html index 10fd4a5..82b149e 100644 --- a/templates/nostrclient/index.html +++ b/templates/nostrclient/index.html @@ -6,18 +6,18 @@
- +
- Add relay - + + + + + + + +
@@ -29,36 +29,18 @@
Nostrclient
- +
- +