diff --git a/__init__.py b/__init__.py
index 019df68..ec6f4b4 100644
--- a/__init__.py
+++ b/__init__.py
@@ -8,6 +8,8 @@ from lnbits.db import Database
from lnbits.helpers import template_renderer
from lnbits.tasks import catch_everything_and_restart
+from .nostr.client.client import NostrClient as NostrClientLib
+
db = Database("ext_nostrclient")
nostrclient_static_files = [
@@ -22,12 +24,19 @@ nostrclient_ext: APIRouter = APIRouter(prefix="/nostrclient", tags=["nostrclient
scheduled_tasks: List[asyncio.Task] = []
+class NostrClient:
+ def __init__(self):
+ self.client: NostrClientLib = NostrClientLib(connect=False)
+
+
+nostr = NostrClient()
+
def nostr_renderer():
return template_renderer(["lnbits/extensions/nostrclient/templates"])
-from .tasks import init_relays, subscribe_events
+from .tasks import check_relays, init_relays, subscribe_events
from .views import * # noqa
from .views_api import * # noqa
@@ -38,3 +47,5 @@ def nostrclient_start():
scheduled_tasks.append(task1)
task2 = loop.create_task(catch_everything_and_restart(subscribe_events))
scheduled_tasks.append(task2)
+ task3 = loop.create_task(catch_everything_and_restart(check_relays))
+ scheduled_tasks.append(task3)
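The extension now schedules three long-running tasks through `catch_everything_and_restart`. A minimal sketch of that restart-wrapper pattern, assuming the lnbits helper simply logs failures and re-runs the coroutine (the real implementation may differ):

    import asyncio
    from loguru import logger

    async def restart_on_failure(func):
        # Hypothetical stand-in for lnbits.tasks.catch_everything_and_restart:
        # run the coroutine forever, logging crashes and restarting after a pause.
        while True:
            try:
                await func()
            except asyncio.CancelledError:
                raise  # let task cancellation propagate normally
            except Exception as exc:
                logger.error(f"Task {getattr(func, '__name__', func)} crashed: {exc}; restarting in 5s")
                await asyncio.sleep(5)
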
diff --git a/models.py b/models.py
index 1456d83..88651fc 100644
--- a/models.py
+++ b/models.py
@@ -8,12 +8,19 @@ from pydantic import BaseModel, Field
from lnbits.helpers import urlsafe_short_hash
+class RelayStatus(BaseModel):
+ num_sent_events: Optional[int] = 0
+ num_received_events: Optional[int] = 0
+ error_counter: Optional[int] = 0
+ error_list: Optional[List] = []
+ notice_list: Optional[List] = []
+
class Relay(BaseModel):
id: Optional[str] = None
url: Optional[str] = None
connected: Optional[bool] = None
connected_string: Optional[str] = None
- status: Optional[str] = None
+ status: Optional[RelayStatus] = None
active: Optional[bool] = None
ping: Optional[int] = None
@@ -59,43 +66,3 @@ class TestMessageResponse(BaseModel):
private_key: str
public_key: str
event_json: str
-
-# class nostrKeys(BaseModel):
-# pubkey: str
-# privkey: str
-
-# class nostrNotes(BaseModel):
-# id: str
-# pubkey: str
-# created_at: str
-# kind: int
-# tags: str
-# content: str
-# sig: str
-
-# class nostrCreateRelays(BaseModel):
-# relay: str = Query(None)
-
-# class nostrCreateConnections(BaseModel):
-# pubkey: str = Query(None)
-# relayid: str = Query(None)
-
-# class nostrRelays(BaseModel):
-# id: Optional[str]
-# relay: Optional[str]
-# status: Optional[bool] = False
-
-
-# class nostrRelaySetList(BaseModel):
-# allowlist: Optional[str]
-# denylist: Optional[str]
-
-# class nostrConnections(BaseModel):
-# id: str
-# pubkey: Optional[str]
-# relayid: Optional[str]
-
-# class nostrSubscriptions(BaseModel):
-# id: str
-# userPubkey: Optional[str]
-# subscribedPubkey: Optional[str]
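With `status` now a nested `RelayStatus` instead of a plain string, serialized relays expose structured counters. A self-contained usage sketch (field values are invented for illustration):

    from typing import List, Optional
    from pydantic import BaseModel

    class RelayStatus(BaseModel):
        num_sent_events: Optional[int] = 0
        num_received_events: Optional[int] = 0
        error_counter: Optional[int] = 0
        error_list: Optional[List] = []
        notice_list: Optional[List] = []

    class Relay(BaseModel):
        url: Optional[str] = None
        connected: Optional[bool] = None
        status: Optional[RelayStatus] = None

    relay = Relay(
        url="wss://relay.example.invalid",
        connected=True,
        status=RelayStatus(num_sent_events=3, error_list=["connection timed out"]),
    )
    print(relay.dict()["status"]["num_sent_events"])  # 3
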
diff --git a/nostr/bech32.py b/nostr/bech32.py
index b068de7..61a92c4 100644
--- a/nostr/bech32.py
+++ b/nostr/bech32.py
@@ -23,6 +23,7 @@
from enum import Enum
+
class Encoding(Enum):
"""Enumeration type to list the various supported encodings."""
BECH32 = 1
diff --git a/nostr/client/cbc.py b/nostr/client/cbc.py
deleted file mode 100644
index a41dbc0..0000000
--- a/nostr/client/cbc.py
+++ /dev/null
@@ -1,41 +0,0 @@
-
-from Cryptodome import Random
-from Cryptodome.Cipher import AES
-
-plain_text = "This is the text to encrypts"
-
-# encrypted = "7mH9jq3K9xNfWqIyu9gNpUz8qBvGwsrDJ+ACExdV1DvGgY8q39dkxVKeXD7LWCDrPnoD/ZFHJMRMis8v9lwHfNgJut8EVTMuJJi8oTgJevOBXl+E+bJPwej9hY3k20rgCQistNRtGHUzdWyOv7S1tg==".encode()
-# iv = "GzDzqOVShWu3Pl2313FBpQ==".encode()
-
-key = bytes.fromhex("3aa925cb69eb613e2928f8a18279c78b1dca04541dfd064df2eda66b59880795")
-
-BLOCK_SIZE = 16
-
-class AESCipher(object):
- """This class is compatible with crypto.createCipheriv('aes-256-cbc')
-
- """
- def __init__(self, key=None):
- self.key = key
-
- def pad(self, data):
- length = BLOCK_SIZE - (len(data) % BLOCK_SIZE)
- return data + (chr(length) * length).encode()
-
- def unpad(self, data):
- return data[: -(data[-1] if type(data[-1]) == int else ord(data[-1]))]
-
- def encrypt(self, plain_text):
- cipher = AES.new(self.key, AES.MODE_CBC)
- b = plain_text.encode("UTF-8")
- return cipher.iv, cipher.encrypt(self.pad(b))
-
- def decrypt(self, iv, enc_text):
- cipher = AES.new(self.key, AES.MODE_CBC, iv=iv)
- return self.unpad(cipher.decrypt(enc_text).decode("UTF-8"))
-
-if __name__ == "__main__":
- aes = AESCipher(key=key)
- iv, enc_text = aes.encrypt(plain_text)
- dec_text = aes.decrypt(iv, enc_text)
- print(dec_text)
\ No newline at end of file
diff --git a/nostr/client/client.py b/nostr/client/client.py
index 6e70f71..66d722c 100644
--- a/nostr/client/client.py
+++ b/nostr/client/client.py
@@ -1,38 +1,14 @@
-from typing import *
-import ssl
import time
-import json
-import os
-import base64
+from typing import List
-from ..event import Event
from ..relay_manager import RelayManager
-from ..message_type import ClientMessageType
-from ..key import PrivateKey, PublicKey
-
-from ..filter import Filter, Filters
-from ..event import Event, EventKind, EncryptedDirectMessage
-from ..relay_manager import RelayManager
-from ..message_type import ClientMessageType
-
-# from aes import AESCipher
-from . import cbc
class NostrClient:
- relays = [
- "wss://nostr-pub.wellorder.net",
- "wss://nostr.zebedee.cloud",
- "wss://nodestr.fmt.wiz.biz",
- "wss://nostr.oxtr.dev",
- ] # ["wss://nostr.oxtr.dev"] # ["wss://relay.nostr.info"] "wss://nostr-pub.wellorder.net" "ws://91.237.88.218:2700", "wss://nostrrr.bublina.eu.org", ""wss://nostr-relay.freeberty.net"", , "wss://nostr.oxtr.dev", "wss://relay.nostr.info", "wss://nostr-pub.wellorder.net" , "wss://relayer.fiatjaf.com", "wss://nodestr.fmt.wiz.biz/", "wss://no.str.cr"
+    relays = []
relay_manager = RelayManager()
- private_key: PrivateKey
- public_key: PublicKey
-
- def __init__(self, privatekey_hex: str = "", relays: List[str] = [], connect=True):
- self.generate_keys(privatekey_hex)
+ def __init__(self, relays: List[str] = [], connect=True):
if len(relays):
self.relays = relays
if connect:
@@ -41,94 +17,10 @@ class NostrClient:
def connect(self):
for relay in self.relays:
self.relay_manager.add_relay(relay)
- self.relay_manager.open_connections(
- {"cert_reqs": ssl.CERT_NONE}
- ) # NOTE: This disables ssl certificate verification
def close(self):
self.relay_manager.close_connections()
- def generate_keys(self, privatekey_hex: str = None):
- pk = bytes.fromhex(privatekey_hex) if privatekey_hex else None
- self.private_key = PrivateKey(pk)
- self.public_key = self.private_key.public_key
-
- def post(self, message: str):
- event = Event(message, self.public_key.hex(), kind=EventKind.TEXT_NOTE)
- self.private_key.sign_event(event)
- event_json = event.to_message()
- # print("Publishing message:")
- # print(event_json)
- self.relay_manager.publish_message(event_json)
-
- def get_post(
- self, sender_publickey: PublicKey = None, callback_func=None, filter_kwargs={}
- ):
- filter = Filter(
- authors=[sender_publickey.hex()] if sender_publickey else None,
- kinds=[EventKind.TEXT_NOTE],
- **filter_kwargs,
- )
- filters = Filters([filter])
- subscription_id = os.urandom(4).hex()
- self.relay_manager.add_subscription(subscription_id, filters)
-
- request = [ClientMessageType.REQUEST, subscription_id]
- request.extend(filters.to_json_array())
- message = json.dumps(request)
- self.relay_manager.publish_message(message)
-
- while True:
- while self.relay_manager.message_pool.has_events():
- event_msg = self.relay_manager.message_pool.get_event()
- if callback_func:
- callback_func(event_msg.event)
- time.sleep(0.1)
-
- def dm(self, message: str, to_pubkey: PublicKey):
- dm = EncryptedDirectMessage(
- recipient_pubkey=to_pubkey.hex(), cleartext_content=message
- )
- self.private_key.sign_event(dm)
- self.relay_manager.publish_event(dm)
-
- def get_dm(self, sender_publickey: PublicKey, callback_func=None):
- filters = Filters(
- [
- Filter(
- kinds=[EventKind.ENCRYPTED_DIRECT_MESSAGE],
- pubkey_refs=[sender_publickey.hex()],
- )
- ]
- )
- subscription_id = os.urandom(4).hex()
- self.relay_manager.add_subscription(subscription_id, filters)
-
- request = [ClientMessageType.REQUEST, subscription_id]
- request.extend(filters.to_json_array())
- message = json.dumps(request)
- self.relay_manager.publish_message(message)
-
- while True:
- while self.relay_manager.message_pool.has_events():
- event_msg = self.relay_manager.message_pool.get_event()
- if "?iv=" in event_msg.event.content:
- try:
- shared_secret = self.private_key.compute_shared_secret(
- event_msg.event.public_key
- )
- aes = cbc.AESCipher(key=shared_secret)
- enc_text_b64, iv_b64 = event_msg.event.content.split("?iv=")
- iv = base64.decodebytes(iv_b64.encode("utf-8"))
- enc_text = base64.decodebytes(enc_text_b64.encode("utf-8"))
- dec_text = aes.decrypt(iv, enc_text)
- if callback_func:
- callback_func(event_msg.event, dec_text)
- except:
- pass
- break
- time.sleep(0.1)
-
def subscribe(
self,
callback_events_func=None,
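The client is reduced to connection management: key handling, posting and DM helpers are gone, and relay connections are opened by `RelayManager.add_relay` instead of a separate `open_connections` call. A minimal usage sketch, assuming the import path matches the extension layout shown in `__init__.py`:

    from lnbits.extensions.nostrclient.nostr.client.client import NostrClient

    client = NostrClient(connect=False)               # no keys are generated anymore
    client.relays = ["wss://relay.example.invalid"]   # placeholder relay URL
    client.connect()    # each URL is handed to RelayManager.add_relay()
    # ... publish / subscribe through client.relay_manager ...
    client.close()      # closes all relay connections
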
diff --git a/nostr/event.py b/nostr/event.py
index b903e0e..65b187d 100644
--- a/nostr/event.py
+++ b/nostr/event.py
@@ -1,10 +1,11 @@
-import time
import json
+import time
from dataclasses import dataclass, field
from enum import IntEnum
-from typing import List
-from secp256k1 import PublicKey
from hashlib import sha256
+from typing import List
+
+from secp256k1 import PublicKey
from .message_type import ClientMessageType
diff --git a/nostr/key.py b/nostr/key.py
index d34697f..8089e11 100644
--- a/nostr/key.py
+++ b/nostr/key.py
@@ -1,14 +1,15 @@
-import secrets
import base64
-import secp256k1
-from cffi import FFI
-from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
-from cryptography.hazmat.primitives import padding
+import secrets
from hashlib import sha256
+import secp256k1
+from cffi import FFI
+from cryptography.hazmat.primitives import padding
+from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
+
+from . import bech32
from .delegation import Delegation
from .event import EncryptedDirectMessage, Event, EventKind
-from . import bech32
class PublicKey:
diff --git a/nostr/message_pool.py b/nostr/message_pool.py
index d364cf2..02f7fd4 100644
--- a/nostr/message_pool.py
+++ b/nostr/message_pool.py
@@ -1,8 +1,9 @@
import json
from queue import Queue
from threading import Lock
-from .message_type import RelayMessageType
+
from .event import Event
+from .message_type import RelayMessageType
class EventMessage:
@@ -68,10 +69,19 @@ class MessagePool:
e["sig"],
)
with self.lock:
- if not event.id in self._unique_events:
- self.events.put(EventMessage(event, subscription_id, url))
- self._unique_events.add(event.id)
+ if not f"{subscription_id}_{event.id}" in self._unique_events:
+ self._accept_event(EventMessage(event, subscription_id, url))
elif message_type == RelayMessageType.NOTICE:
self.notices.put(NoticeMessage(message_json[1], url))
elif message_type == RelayMessageType.END_OF_STORED_EVENTS:
self.eose_notices.put(EndOfStoredEventsMessage(message_json[1], url))
+
+ def _accept_event(self, event_message: EventMessage):
+ """
+ Event uniqueness is considered per `subscription_id`.
+        The `subscription_id` is rewritten to be unique and is the same across relays.
+ The same event can come from different subscriptions (from the same client or from different ones).
+ Clients that have joined later should receive older events.
+ """
+ self.events.put(event_message)
+ self._unique_events.add(f"{event_message.subscription_id}_{event_message.event.id}")
\ No newline at end of file
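The dedup key now combines the subscription id with the event id, so the same event is delivered at most once per subscription rather than once globally. A toy illustration of the key scheme:

    seen_keys: set = set()

    def accept(subscription_id: str, event_id: str) -> bool:
        # mirrors the f"{subscription_id}_{event_id}" uniqueness key above
        key = f"{subscription_id}_{event_id}"
        if key in seen_keys:
            return False
        seen_keys.add(key)
        return True

    assert accept("sub-a", "ev1") is True
    assert accept("sub-a", "ev1") is False   # duplicate within the same subscription
    assert accept("sub-b", "ev1") is True    # same event, different subscription
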
diff --git a/nostr/message_type.py b/nostr/message_type.py
index 3f5206b..d37cdfd 100644
--- a/nostr/message_type.py
+++ b/nostr/message_type.py
@@ -3,13 +3,20 @@ class ClientMessageType:
REQUEST = "REQ"
CLOSE = "CLOSE"
+
class RelayMessageType:
EVENT = "EVENT"
NOTICE = "NOTICE"
END_OF_STORED_EVENTS = "EOSE"
+ COMMAND_RESULT = "OK"
@staticmethod
def is_valid(type: str) -> bool:
- if type == RelayMessageType.EVENT or type == RelayMessageType.NOTICE or type == RelayMessageType.END_OF_STORED_EVENTS:
+ if (
+ type == RelayMessageType.EVENT
+ or type == RelayMessageType.NOTICE
+ or type == RelayMessageType.END_OF_STORED_EVENTS
+ or type == RelayMessageType.COMMAND_RESULT
+ ):
return True
- return False
\ No newline at end of file
+ return False
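The new `COMMAND_RESULT` type corresponds to the relay `OK` message (NIP-20 style), typically shaped like `["OK", <event_id>, <true|false>, <message>]`. A standalone sketch of the validator and how such messages would be checked:

    import json

    class RelayMessageType:
        EVENT = "EVENT"
        NOTICE = "NOTICE"
        END_OF_STORED_EVENTS = "EOSE"
        COMMAND_RESULT = "OK"

        VALID = {EVENT, NOTICE, END_OF_STORED_EVENTS, COMMAND_RESULT}

        @staticmethod
        def is_valid(type: str) -> bool:
            return type in RelayMessageType.VALID

    accepted = json.loads(json.dumps(["OK", "ab" * 32, True, ""]))
    rejected = json.loads(json.dumps(["OK", "ab" * 32, False, "blocked: rate-limited"]))
    for msg in (accepted, rejected):
        assert RelayMessageType.is_valid(msg[0])   # "OK" is now a valid relay message type
        print("accepted" if msg[2] else f"rejected: {msg[3]}")
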
diff --git a/nostr/pow.py b/nostr/pow.py
deleted file mode 100644
index e006288..0000000
--- a/nostr/pow.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import time
-from .event import Event
-from .key import PrivateKey
-
-def zero_bits(b: int) -> int:
- n = 0
-
- if b == 0:
- return 8
-
- while b >> 1:
- b = b >> 1
- n += 1
-
- return 7 - n
-
-def count_leading_zero_bits(hex_str: str) -> int:
- total = 0
- for i in range(0, len(hex_str) - 2, 2):
- bits = zero_bits(int(hex_str[i:i+2], 16))
- total += bits
-
- if bits != 8:
- break
-
- return total
-
-def mine_event(content: str, difficulty: int, public_key: str, kind: int, tags: list=[]) -> Event:
- all_tags = [["nonce", "1", str(difficulty)]]
- all_tags.extend(tags)
-
- created_at = int(time.time())
- event_id = Event.compute_id(public_key, created_at, kind, all_tags, content)
- num_leading_zero_bits = count_leading_zero_bits(event_id)
-
- attempts = 1
- while num_leading_zero_bits < difficulty:
- attempts += 1
- all_tags[0][1] = str(attempts)
- created_at = int(time.time())
- event_id = Event.compute_id(public_key, created_at, kind, all_tags, content)
- num_leading_zero_bits = count_leading_zero_bits(event_id)
-
- return Event(public_key, content, created_at, kind, all_tags, event_id)
-
-def mine_key(difficulty: int) -> PrivateKey:
- sk = PrivateKey()
- num_leading_zero_bits = count_leading_zero_bits(sk.public_key.hex())
-
- while num_leading_zero_bits < difficulty:
- sk = PrivateKey()
- num_leading_zero_bits = count_leading_zero_bits(sk.public_key.hex())
-
- return sk
diff --git a/nostr/relay.py b/nostr/relay.py
index 7fb4baa..8b081a3 100644
--- a/nostr/relay.py
+++ b/nostr/relay.py
@@ -2,7 +2,11 @@ import json
import time
from queue import Queue
from threading import Lock
+from typing import List
+
+from loguru import logger
from websocket import WebSocketApp
+
from .event import Event
from .filter import Filters
from .message_pool import MessagePool
@@ -36,6 +40,9 @@ class Relay:
self.shutdown: bool = False
self.error_counter: int = 0
self.error_threshold: int = 100
+ self.error_list: List[str] = []
+ self.notice_list: List[str] = []
+ self.last_error_date: int = 0
self.num_received_events: int = 0
self.num_sent_events: int = 0
self.num_subscriptions: int = 0
@@ -67,17 +74,12 @@ class Relay:
def close(self):
self.ws.close()
+ self.connected = False
self.shutdown = True
- def check_reconnect(self):
- try:
- self.close()
- except:
- pass
- self.connected = False
- if self.reconnect:
- time.sleep(self.error_counter**2)
- self.connect(self.ssl_options, self.proxy)
+ @property
+ def error_threshold_reached(self):
+ return self.error_threshold and self.error_counter >= self.error_threshold
@property
def ping(self):
@@ -87,15 +89,22 @@ class Relay:
def publish(self, message: str):
self.queue.put(message)
- def queue_worker(self, shutdown):
+ def publish_subscriptions(self):
+ for _, subscription in self.subscriptions.items():
+ s = subscription.to_json_object()
+ json_str = json.dumps(["REQ", s["id"], s["filters"][0]])
+ self.publish(json_str)
+
+ def queue_worker(self):
while True:
if self.connected:
try:
message = self.queue.get(timeout=1)
self.num_sent_events += 1
self.ws.send(message)
- except:
- if shutdown():
+ except Exception as e:
+ if self.shutdown:
+ logger.warning(f"Closing queue worker for '{self.url}'.")
break
else:
time.sleep(0.1)
@@ -107,11 +116,7 @@ class Relay:
def close_subscription(self, id: str) -> None:
with self.lock:
self.subscriptions.pop(id)
-
- def update_subscription(self, id: str, filters: Filters) -> None:
- with self.lock:
- subscription = self.subscriptions[id]
- subscription.filters = filters
+ self.publish(json.dumps(["CLOSE", id]))
def to_json_object(self) -> dict:
return {
@@ -123,31 +128,32 @@ class Relay:
],
}
- def _on_open(self, class_obj):
+ def add_notice(self, notice: str):
+ self.notice_list = ([notice] + self.notice_list)[:20]
+
+ def _on_open(self, _):
+ logger.info(f"Connected to relay: '{self.url}'.")
self.connected = True
- pass
+
+ def _on_close(self, _, status_code, message):
+ logger.warning(f"Connection to relay {self.url} closed. Status: '{status_code}'. Message: '{message}'.")
+ self.close()
- def _on_close(self, class_obj, status_code, message):
- self.connected = False
- if self.error_threshold and self.error_counter > self.error_threshold:
- pass
- else:
- self.check_reconnect()
- pass
-
- def _on_message(self, class_obj, message: str):
+ def _on_message(self, _, message: str):
if self._is_valid_message(message):
self.num_received_events += 1
self.message_pool.add_message(message, self.url)
- def _on_error(self, class_obj, error):
+ def _on_error(self, _, error):
+ logger.warning(f"Relay error: '{str(error)}'")
+ self._append_error_message(str(error))
self.connected = False
self.error_counter += 1
- def _on_ping(self, class_obj, message):
+ def _on_ping(self, *_):
return
- def _on_pong(self, class_obj, message):
+ def _on_pong(self, *_):
return
def _is_valid_message(self, message: str) -> bool:
@@ -157,33 +163,58 @@ class Relay:
message_json = json.loads(message)
message_type = message_json[0]
+
if not RelayMessageType.is_valid(message_type):
return False
+
if message_type == RelayMessageType.EVENT:
- if not len(message_json) == 3:
- return False
-
- subscription_id = message_json[1]
- with self.lock:
- if subscription_id not in self.subscriptions:
- return False
-
- e = message_json[2]
- event = Event(
- e["content"],
- e["pubkey"],
- e["created_at"],
- e["kind"],
- e["tags"],
- e["sig"],
- )
- if not event.verify():
- return False
-
- with self.lock:
- subscription = self.subscriptions[subscription_id]
-
- if subscription.filters and not subscription.filters.match(event):
- return False
+ return self._is_valid_event_message(message_json)
+
+ if message_type == RelayMessageType.COMMAND_RESULT:
+ return self._is_valid_command_result_message(message, message_json)
return True
+
+ def _is_valid_event_message(self, message_json):
+ if not len(message_json) == 3:
+ return False
+
+ subscription_id = message_json[1]
+ with self.lock:
+ if subscription_id not in self.subscriptions:
+ return False
+
+ e = message_json[2]
+ event = Event(
+ e["content"],
+ e["pubkey"],
+ e["created_at"],
+ e["kind"],
+ e["tags"],
+ e["sig"],
+ )
+ if not event.verify():
+ return False
+
+ with self.lock:
+ subscription = self.subscriptions[subscription_id]
+
+ if subscription.filters and not subscription.filters.match(event):
+ return False
+
+ return True
+
+ def _is_valid_command_result_message(self, message, message_json):
+        if len(message_json) < 3:
+ return False
+
+ if message_json[2] != True:
+ logger.warning(f"Relay '{self.url}' negative command result: '{message}'")
+ self._append_error_message(message)
+ return False
+
+ return True
+
+ def _append_error_message(self, message):
+ self.error_list = ([message] + self.error_list)[:20]
+ self.last_error_date = int(time.time())
\ No newline at end of file
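Relay errors and notices are now kept in bounded, most-recent-first lists (capped at 20 entries), and `error_threshold_reached` gates restarts in the relay manager. A short illustration of the capping idiom used by `add_notice` and `_append_error_message`:

    MAX_KEPT = 20

    def push_capped(items: list, new_item: str) -> list:
        # newest first, keep at most MAX_KEPT entries
        return ([new_item] + items)[:MAX_KEPT]

    errors: list = []
    for i in range(25):
        errors = push_capped(errors, f"error {i}")

    print(len(errors))   # 20
    print(errors[0])     # 'error 24' -- the most recent entry
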
diff --git a/nostr/relay_manager.py b/nostr/relay_manager.py
index a698a33..b2df735 100644
--- a/nostr/relay_manager.py
+++ b/nostr/relay_manager.py
@@ -1,11 +1,14 @@
-import json
-import threading
-from .event import Event
+import ssl
+import threading
+import time
+
+from loguru import logger
+
from .filter import Filters
-from .message_pool import MessagePool
-from .message_type import ClientMessageType
+from .message_pool import MessagePool, NoticeMessage
from .relay import Relay, RelayPolicy
+from .subscription import Subscription
class RelayException(Exception):
@@ -18,45 +21,55 @@ class RelayManager:
self.threads: dict[str, threading.Thread] = {}
self.queue_threads: dict[str, threading.Thread] = {}
self.message_pool = MessagePool()
+ self._cached_subscriptions: dict[str, Subscription] = {}
+ self._subscriptions_lock = threading.Lock()
+
+ def add_relay(self, url: str, read: bool = True, write: bool = True) -> Relay:
+ if url in list(self.relays.keys()):
+ return
+
+ with self._subscriptions_lock:
+ subscriptions = self._cached_subscriptions.copy()
- def add_relay(
- self, url: str, read: bool = True, write: bool = True, subscriptions={}
- ):
policy = RelayPolicy(read, write)
relay = Relay(url, policy, self.message_pool, subscriptions)
self.relays[url] = relay
+ self._open_connection(
+ relay,
+ {"cert_reqs": ssl.CERT_NONE}
+ ) # NOTE: This disables ssl certificate verification
+
+ relay.publish_subscriptions()
+ return relay
+
def remove_relay(self, url: str):
- self.relays[url].close()
- self.relays.pop(url)
self.threads[url].join(timeout=1)
self.threads.pop(url)
+ self.queue_threads[url].join(timeout=1)
+ self.queue_threads.pop(url)
+ self.relays[url].close()
+ self.relays.pop(url)
def add_subscription(self, id: str, filters: Filters):
+ with self._subscriptions_lock:
+ self._cached_subscriptions[id] = Subscription(id, filters)
+
for relay in self.relays.values():
relay.add_subscription(id, filters)
def close_subscription(self, id: str):
+ with self._subscriptions_lock:
+ self._cached_subscriptions.pop(id)
+
for relay in self.relays.values():
relay.close_subscription(id)
- def open_connections(self, ssl_options: dict = None, proxy: dict = None):
- for relay in self.relays.values():
- self.threads[relay.url] = threading.Thread(
- target=relay.connect,
- args=(ssl_options, proxy),
- name=f"{relay.url}-thread",
- daemon=True,
- )
- self.threads[relay.url].start()
+ def check_and_restart_relays(self):
+ stopped_relays = [r for r in self.relays.values() if r.shutdown]
+ for relay in stopped_relays:
+ self._restart_relay(relay)
- self.queue_threads[relay.url] = threading.Thread(
- target=relay.queue_worker,
- args=(lambda: relay.shutdown,),
- name=f"{relay.url}-queue",
- daemon=True,
- )
- self.queue_threads[relay.url].start()
def close_connections(self):
for relay in self.relays.values():
@@ -67,13 +80,38 @@ class RelayManager:
if relay.policy.should_write:
relay.publish(message)
- def publish_event(self, event: Event):
- """Verifies that the Event is publishable before submitting it to relays"""
- if event.signature is None:
- raise RelayException(f"Could not publish {event.id}: must be signed")
+ def handle_notice(self, notice: NoticeMessage):
+        relay = next((r for r in self.relays.values() if r.url == notice.url), None)
+ if relay:
+ relay.add_notice(notice.content)
- if not event.verify():
- raise RelayException(
- f"Could not publish {event.id}: failed to verify signature {event.signature}"
- )
- self.publish_message(event.to_message())
+ def _open_connection(self, relay: Relay, ssl_options: dict = None, proxy: dict = None):
+ self.threads[relay.url] = threading.Thread(
+ target=relay.connect,
+ args=(ssl_options, proxy),
+ name=f"{relay.url}-thread",
+ daemon=True,
+ )
+ self.threads[relay.url].start()
+
+ self.queue_threads[relay.url] = threading.Thread(
+ target=relay.queue_worker,
+ name=f"{relay.url}-queue",
+ daemon=True,
+ )
+ self.queue_threads[relay.url].start()
+
+ def _restart_relay(self, relay: Relay):
+ if relay.error_threshold_reached:
+ time_since_last_error = time.time() - relay.last_error_date
+            if time_since_last_error < 60 * 60 * 2:  # two hours
+ return
+ relay.error_counter = 0
+ relay.error_list = []
+
+ logger.info(f"Restarting connection to relay '{relay.url}'")
+
+ self.remove_relay(relay.url)
+ new_relay = self.add_relay(relay.url)
+ new_relay.error_counter = relay.error_counter
+ new_relay.error_list = relay.error_list
\ No newline at end of file
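`check_and_restart_relays` reopens relays whose websocket has shut down, but a relay that hit its error threshold is only restarted after a two-hour cooldown since its last error (its counters are then reset). A condensed sketch of that decision:

    import time

    COOLDOWN_SECONDS = 60 * 60 * 2   # two hours, as in _restart_relay above

    def should_restart(error_threshold_reached: bool, last_error_date: int) -> bool:
        if not error_threshold_reached:
            return True
        # threshold reached: wait out the cooldown before trying again
        return time.time() - last_error_date >= COOLDOWN_SECONDS

    print(should_restart(False, 0))                  # True
    print(should_restart(True, int(time.time())))    # False, still cooling down
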
diff --git a/nostr/subscription.py b/nostr/subscription.py
index 7afba20..76da0af 100644
--- a/nostr/subscription.py
+++ b/nostr/subscription.py
@@ -1,5 +1,6 @@
from .filter import Filters
+
class Subscription:
def __init__(self, id: str, filters: Filters=None) -> None:
self.id = id
diff --git a/router.py b/router.py
new file mode 100644
index 0000000..e85653c
--- /dev/null
+++ b/router.py
@@ -0,0 +1,190 @@
+import asyncio
+import json
+from typing import List, Union
+
+from fastapi import WebSocketDisconnect
+from loguru import logger
+
+from lnbits.helpers import urlsafe_short_hash
+
+from . import nostr
+from .models import Event, Filter
+from .nostr.filter import Filter as NostrFilter
+from .nostr.filter import Filters as NostrFilters
+from .nostr.message_pool import EndOfStoredEventsMessage, NoticeMessage
+
+
+class NostrRouter:
+
+ received_subscription_events: dict[str, list[Event]] = {}
+ received_subscription_notices: list[NoticeMessage] = []
+ received_subscription_eosenotices: dict[str, EndOfStoredEventsMessage] = {}
+
+ def __init__(self, websocket):
+ self.subscriptions: List[str] = []
+ self.connected: bool = True
+ self.websocket = websocket
+ self.tasks: List[asyncio.Task] = []
+ self.original_subscription_ids = {}
+
+ async def client_to_nostr(self):
+ """Receives requests / data from the client and forwards it to relays. If the
+ request was a subscription/filter, registers it with the nostr client lib.
+ Remembers the subscription id so we can send back responses from the relay to this
+ client in `nostr_to_client`"""
+ while True:
+ try:
+ json_str = await self.websocket.receive_text()
+ except WebSocketDisconnect:
+ self.connected = False
+ break
+
+ try:
+ await self._handle_client_to_nostr(json_str)
+ except Exception as e:
+ logger.debug(f"Failed to handle client message: '{str(e)}'.")
+
+
+ async def nostr_to_client(self):
+ """Sends responses from relays back to the client. Polls the subscriptions of this client
+ stored in `my_subscriptions`. Then gets all responses for this subscription id from `received_subscription_events` which
+ is filled in tasks.py. Takes one response after the other and relays it back to the client. Reconstructs
+ the reponse manually because the nostr client lib we're using can't do it. Reconstructs the original subscription id
+ that we had previously rewritten in order to avoid collisions when multiple clients use the same id.
+ """
+        while self.connected:
+ try:
+ await self._handle_subscriptions()
+ self._handle_notices()
+ except Exception as e:
+ logger.debug(f"Failed to handle response for client: '{str(e)}'.")
+ await asyncio.sleep(0.1)
+
+
+ async def start(self):
+ self.tasks.append(asyncio.create_task(self.client_to_nostr()))
+ self.tasks.append(asyncio.create_task(self.nostr_to_client()))
+
+ async def stop(self):
+ for t in self.tasks:
+ try:
+ t.cancel()
+ except:
+ pass
+
+ for s in self.subscriptions:
+ try:
+ nostr.client.relay_manager.close_subscription(s)
+ except:
+ pass
+ self.connected = False
+
+ async def _handle_subscriptions(self):
+ for s in self.subscriptions:
+ if s in NostrRouter.received_subscription_events:
+ await self._handle_received_subscription_events(s)
+ if s in NostrRouter.received_subscription_eosenotices:
+ await self._handle_received_subscription_eosenotices(s)
+
+
+
+ async def _handle_received_subscription_eosenotices(self, s):
+ s_original = self.original_subscription_ids[s]
+ event_to_forward = ["EOSE", s_original]
+ del NostrRouter.received_subscription_eosenotices[s]
+
+ await self.websocket.send_text(json.dumps(event_to_forward))
+
+ async def _handle_received_subscription_events(self, s):
+ while len(NostrRouter.received_subscription_events[s]):
+ my_event = NostrRouter.received_subscription_events[s].pop(0)
+ # event.to_message() does not include the subscription ID, we have to add it manually
+ event_json = {
+ "id": my_event.id,
+ "pubkey": my_event.public_key,
+ "created_at": my_event.created_at,
+ "kind": my_event.kind,
+ "tags": my_event.tags,
+ "content": my_event.content,
+ "sig": my_event.signature,
+ }
+
+ # this reconstructs the original response from the relay
+ # reconstruct original subscription id
+ s_original = self.original_subscription_ids[s]
+ event_to_forward = ["EVENT", s_original, event_json]
+ await self.websocket.send_text(json.dumps(event_to_forward))
+
+ def _handle_notices(self):
+ while len(NostrRouter.received_subscription_notices):
+ my_event = NostrRouter.received_subscription_notices.pop(0)
+ # note: we don't send it to the user because we don't know who should receive it
+ logger.info(f"Relay ('{my_event.url}') notice: '{my_event.content}']")
+ nostr.client.relay_manager.handle_notice(my_event)
+
+
+
+ def _marshall_nostr_filters(self, data: Union[dict, list]):
+ filters = data if isinstance(data, list) else [data]
+ filters = [Filter.parse_obj(f) for f in filters]
+ filter_list: list[NostrFilter] = []
+ for filter in filters:
+ filter_list.append(
+ NostrFilter(
+ event_ids=filter.ids, # type: ignore
+ kinds=filter.kinds, # type: ignore
+ authors=filter.authors, # type: ignore
+ since=filter.since, # type: ignore
+ until=filter.until, # type: ignore
+ event_refs=filter.e, # type: ignore
+ pubkey_refs=filter.p, # type: ignore
+ limit=filter.limit, # type: ignore
+ )
+ )
+ return NostrFilters(filter_list)
+
+ async def _handle_client_to_nostr(self, json_str):
+ """Parses a (string) request from a client. If it is a subscription (REQ) or a CLOSE, it will
+ register the subscription in the nostr client library that we're using so we can
+ receive the callbacks on it later. Will rewrite the subscription id since we expect
+ multiple clients to use the router and want to avoid subscription id collisions
+ """
+
+ json_data = json.loads(json_str)
+ assert len(json_data)
+
+
+ if json_data[0] == "REQ":
+ self._handle_client_req(json_data)
+ return
+
+ if json_data[0] == "CLOSE":
+ self._handle_client_close(json_data[1])
+ return
+
+ if json_data[0] == "EVENT":
+ nostr.client.relay_manager.publish_message(json_str)
+ return
+
+ def _handle_client_req(self, json_data):
+ subscription_id = json_data[1]
+ subscription_id_rewritten = urlsafe_short_hash()
+ self.original_subscription_ids[subscription_id_rewritten] = subscription_id
+ fltr = json_data[2]
+ filters = self._marshall_nostr_filters(fltr)
+
+ nostr.client.relay_manager.add_subscription(
+ subscription_id_rewritten, filters
+ )
+ request_rewritten = json.dumps([json_data[0], subscription_id_rewritten, fltr])
+
+ self.subscriptions.append(subscription_id_rewritten)
+ nostr.client.relay_manager.publish_message(request_rewritten)
+
+ def _handle_client_close(self, subscription_id):
+ subscription_id_rewritten = next((k for k, v in self.original_subscription_ids.items() if v == subscription_id), None)
+ if subscription_id_rewritten:
+ self.original_subscription_ids.pop(subscription_id_rewritten)
+ nostr.client.relay_manager.close_subscription(subscription_id_rewritten)
+ else:
+ logger.debug(f"Failed to unsubscribe from '{subscription_id}.'")
diff --git a/services.py b/services.py
deleted file mode 100644
index 82f6578..0000000
--- a/services.py
+++ /dev/null
@@ -1,163 +0,0 @@
-import asyncio
-import json
-from typing import List, Union
-
-from fastapi import WebSocket, WebSocketDisconnect
-from loguru import logger
-
-from lnbits.helpers import urlsafe_short_hash
-
-from .models import Event, Filter, Filters, Relay, RelayList
-from .nostr.client.client import NostrClient as NostrClientLib
-from .nostr.event import Event as NostrEvent
-from .nostr.filter import Filter as NostrFilter
-from .nostr.filter import Filters as NostrFilters
-from .nostr.message_pool import EndOfStoredEventsMessage, EventMessage, NoticeMessage
-
-received_subscription_events: dict[str, list[Event]] = {}
-received_subscription_notices: list[NoticeMessage] = []
-received_subscription_eosenotices: dict[str, EndOfStoredEventsMessage] = {}
-
-
-class NostrClient:
- def __init__(self):
- self.client: NostrClientLib = NostrClientLib(connect=False)
-
-
-nostr = NostrClient()
-
-
-class NostrRouter:
- def __init__(self, websocket):
- self.subscriptions: List[str] = []
- self.connected: bool = True
- self.websocket = websocket
- self.tasks: List[asyncio.Task] = []
- self.oridinal_subscription_ids = {}
-
- async def client_to_nostr(self):
- """Receives requests / data from the client and forwards it to relays. If the
- request was a subscription/filter, registers it with the nostr client lib.
- Remembers the subscription id so we can send back responses from the relay to this
- client in `nostr_to_client`"""
- while True:
- try:
- json_str = await self.websocket.receive_text()
- except WebSocketDisconnect:
- self.connected = False
- break
-
- # registers a subscription if the input was a REQ request
- subscription_id, json_str_rewritten = await self._add_nostr_subscription(
- json_str
- )
-
- if subscription_id and json_str_rewritten:
- self.subscriptions.append(subscription_id)
- json_str = json_str_rewritten
-
- # publish data
- nostr.client.relay_manager.publish_message(json_str)
-
- async def nostr_to_client(self):
- """Sends responses from relays back to the client. Polls the subscriptions of this client
- stored in `my_subscriptions`. Then gets all responses for this subscription id from `received_subscription_events` which
- is filled in tasks.py. Takes one response after the other and relays it back to the client. Reconstructs
- the reponse manually because the nostr client lib we're using can't do it. Reconstructs the original subscription id
- that we had previously rewritten in order to avoid collisions when multiple clients use the same id.
- """
- while True and self.connected:
- for s in self.subscriptions:
- if s in received_subscription_events:
- while len(received_subscription_events[s]):
- my_event = received_subscription_events[s].pop(0)
- # event.to_message() does not include the subscription ID, we have to add it manually
- event_json = {
- "id": my_event.id,
- "pubkey": my_event.public_key,
- "created_at": my_event.created_at,
- "kind": my_event.kind,
- "tags": my_event.tags,
- "content": my_event.content,
- "sig": my_event.signature,
- }
-
- # this reconstructs the original response from the relay
- # reconstruct original subscription id
- s_original = self.oridinal_subscription_ids[s]
- event_to_forward = ["EVENT", s_original, event_json]
-
- # print("Event to forward")
- # print(json.dumps(event_to_forward))
-
- # send data back to client
- await self.websocket.send_text(json.dumps(event_to_forward))
- if s in received_subscription_eosenotices:
- my_event = received_subscription_eosenotices[s]
- s_original = self.oridinal_subscription_ids[s]
- event_to_forward = ["EOSE", s_original]
- del received_subscription_eosenotices[s]
- # send data back to client
- # print("Sending EOSE", event_to_forward)
- await self.websocket.send_text(json.dumps(event_to_forward))
-
- # if s in received_subscription_notices:
- while len(received_subscription_notices):
- my_event = received_subscription_notices.pop(0)
- event_to_forward = ["NOTICE", my_event.content]
- # send data back to client
- logger.debug("Nostrclient: Received notice", event_to_forward[1])
- # note: we don't send it to the user because we don't know who should receive it
- # await self.websocket.send_text(json.dumps(event_to_forward))
- await asyncio.sleep(0.1)
-
- async def start(self):
- self.tasks.append(asyncio.create_task(self.client_to_nostr()))
- self.tasks.append(asyncio.create_task(self.nostr_to_client()))
-
- async def stop(self):
- for t in self.tasks:
- t.cancel()
- self.connected = False
-
- def _marshall_nostr_filters(self, data: Union[dict, list]):
- filters = data if isinstance(data, list) else [data]
- filters = [Filter.parse_obj(f) for f in filters]
- filter_list: list[NostrFilter] = []
- for filter in filters:
- filter_list.append(
- NostrFilter(
- event_ids=filter.ids, # type: ignore
- kinds=filter.kinds, # type: ignore
- authors=filter.authors, # type: ignore
- since=filter.since, # type: ignore
- until=filter.until, # type: ignore
- event_refs=filter.e, # type: ignore
- pubkey_refs=filter.p, # type: ignore
- limit=filter.limit, # type: ignore
- )
- )
- return NostrFilters(filter_list)
-
- async def _add_nostr_subscription(self, json_str):
- """Parses a (string) request from a client. If it is a subscription (REQ) or a CLOSE, it will
- register the subscription in the nostr client library that we're using so we can
- receive the callbacks on it later. Will rewrite the subscription id since we expect
- multiple clients to use the router and want to avoid subscription id collisions
- """
- json_data = json.loads(json_str)
- assert len(json_data)
- if json_data[0] in ["REQ", "CLOSE"]:
- subscription_id = json_data[1]
- subscription_id_rewritten = urlsafe_short_hash()
- self.oridinal_subscription_ids[subscription_id_rewritten] = subscription_id
- fltr = json_data[2]
- filters = self._marshall_nostr_filters(fltr)
- nostr.client.relay_manager.add_subscription(
- subscription_id_rewritten, filters
- )
- request_rewritten = json.dumps(
- [json_data[0], subscription_id_rewritten, fltr]
- )
- return subscription_id_rewritten, request_rewritten
- return None, None
diff --git a/tasks.py b/tasks.py
index beff9db..05057e7 100644
--- a/tasks.py
+++ b/tasks.py
@@ -1,82 +1,66 @@
import asyncio
-import json
-import ssl
import threading
+from loguru import logger
+
+from . import nostr
from .crud import get_relays
-from .nostr.event import Event
-from .nostr.key import PublicKey
from .nostr.message_pool import EndOfStoredEventsMessage, EventMessage, NoticeMessage
-from .nostr.relay_manager import RelayManager
-from .services import (
- nostr,
- received_subscription_eosenotices,
- received_subscription_events,
- received_subscription_notices,
-)
+from .router import NostrRouter, nostr
async def init_relays():
- # we save any subscriptions teporarily to re-add them after reinitializing the client
- subscriptions = {}
- for relay in nostr.client.relay_manager.relays.values():
- # relay.add_subscription(id, filters)
- for subscription_id, filters in relay.subscriptions.items():
- subscriptions[subscription_id] = filters
-
# reinitialize the entire client
nostr.__init__()
# get relays from db
relays = await get_relays()
# set relays and connect to them
nostr.client.relays = list(set([r.url for r in relays.__root__ if r.url]))
- nostr.client.connect()
+ await nostr.client.connect()
- await asyncio.sleep(2)
- # re-add subscriptions
- for subscription_id, subscription in subscriptions.items():
- nostr.client.relay_manager.add_subscription(
- subscription_id, subscription.filters
- )
- s = subscription.to_json_object()
- json_str = json.dumps(["REQ", s["id"], s["filters"][0]])
- nostr.client.relay_manager.publish_message(json_str)
- return
+async def check_relays():
+ """ Check relays that have been disconnected """
+ while True:
+ try:
+ await asyncio.sleep(20)
+ nostr.client.relay_manager.check_and_restart_relays()
+ except Exception as e:
+ logger.warning(f"Cannot restart relays: '{str(e)}'.")
+
async def subscribe_events():
while not any([r.connected for r in nostr.client.relay_manager.relays.values()]):
await asyncio.sleep(2)
def callback_events(eventMessage: EventMessage):
- # print(f"From {event.public_key[:3]}..{event.public_key[-3:]}: {event.content}")
- if eventMessage.subscription_id in received_subscription_events:
+ if eventMessage.subscription_id in NostrRouter.received_subscription_events:
# do not add duplicate events (by event id)
if eventMessage.event.id in set(
[
e.id
- for e in received_subscription_events[eventMessage.subscription_id]
+ for e in NostrRouter.received_subscription_events[eventMessage.subscription_id]
]
):
return
- received_subscription_events[eventMessage.subscription_id].append(
+ NostrRouter.received_subscription_events[eventMessage.subscription_id].append(
eventMessage.event
)
else:
- received_subscription_events[eventMessage.subscription_id] = [
+ NostrRouter.received_subscription_events[eventMessage.subscription_id] = [
eventMessage.event
]
return
def callback_notices(noticeMessage: NoticeMessage):
- if noticeMessage not in received_subscription_notices:
- received_subscription_notices.append(noticeMessage)
+ if noticeMessage not in NostrRouter.received_subscription_notices:
+ NostrRouter.received_subscription_notices.append(noticeMessage)
return
def callback_eose_notices(eventMessage: EndOfStoredEventsMessage):
- if eventMessage.subscription_id not in received_subscription_eosenotices:
- received_subscription_eosenotices[
+ if eventMessage.subscription_id not in NostrRouter.received_subscription_eosenotices:
+ NostrRouter.received_subscription_eosenotices[
eventMessage.subscription_id
] = eventMessage
diff --git a/templates/nostrclient/index.html b/templates/nostrclient/index.html
index 10fd4a5..82b149e 100644
--- a/templates/nostrclient/index.html
+++ b/templates/nostrclient/index.html
@@ -6,18 +6,18 @@
-