Stabilize (#24)

* refactor: clean-up

* refactor: extra logs plus try-catch

* refactor: do not use bare `except`

* refactor: clean-up redundant fields

* chore: pass code checks

* chore: code format

* refactor: code clean-up

* fix: refactoring stuff

* refactor: remove un-used file

* chore: code clean-up

* chore: code clean-up

* chore: code-format fix

* refactor: remove nostr.client wrapper

* refactor: code clean-up

* chore: code format

* refactor: remove `RelayList` class

* refactor: extract smaller methods with try-catch

* fix: better exception handling

* fix: remove redundant filters

* fix: simplify event

* chore: code format

* fix: code check

* fix: code check

* fix: simplify `REQ`

* fix: more clean-ups

* refactor: use simpler method

* refactor: re-order and rename

* fix: stop logic

* fix: subscription close before disconnect

* chore: play commit
This commit is contained in:
Vlad Stan 2023-11-01 17:46:42 +02:00 committed by GitHub
parent ab185bd2c4
commit 16ae9d15a1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 522 additions and 717 deletions

View file

@ -7,7 +7,7 @@ from lnbits.db import Database
from lnbits.helpers import template_renderer from lnbits.helpers import template_renderer
from lnbits.tasks import catch_everything_and_restart from lnbits.tasks import catch_everything_and_restart
from .nostr.client.client import NostrClient as NostrClientLib from .nostr.client.client import NostrClient
db = Database("ext_nostrclient") db = Database("ext_nostrclient")
@ -22,19 +22,14 @@ nostrclient_ext: APIRouter = APIRouter(prefix="/nostrclient", tags=["nostrclient
scheduled_tasks: List[asyncio.Task] = [] scheduled_tasks: List[asyncio.Task] = []
class NostrClient: nostr_client = NostrClient()
def __init__(self):
self.client: NostrClientLib = NostrClientLib(connect=False)
nostr = NostrClient()
def nostr_renderer(): def nostr_renderer():
return template_renderer(["nostrclient/templates"]) return template_renderer(["nostrclient/templates"])
from .tasks import check_relays, init_relays, subscribe_events from .tasks import check_relays, init_relays, subscribe_events # noqa
from .views import * # noqa from .views import * # noqa
from .views_api import * # noqa from .views_api import * # noqa

26
cbc.py
View file

@ -1,26 +0,0 @@
from Cryptodome.Cipher import AES
BLOCK_SIZE = 16
class AESCipher(object):
"""This class is compatible with crypto.createCipheriv('aes-256-cbc')"""
def __init__(self, key=None):
self.key = key
def pad(self, data):
length = BLOCK_SIZE - (len(data) % BLOCK_SIZE)
return data + (chr(length) * length).encode()
def unpad(self, data):
return data[: -(data[-1] if type(data[-1]) == int else ord(data[-1]))]
def encrypt(self, plain_text):
cipher = AES.new(self.key, AES.MODE_CBC)
b = plain_text.encode("UTF-8")
return cipher.iv, cipher.encrypt(self.pad(b))
def decrypt(self, iv, enc_text):
cipher = AES.new(self.key, AES.MODE_CBC, iv=iv)
return self.unpad(cipher.decrypt(enc_text).decode("UTF-8"))

16
crud.py
View file

@ -1,21 +1,17 @@
from typing import List, Optional, Union from typing import List
import shortuuid
from lnbits.helpers import urlsafe_short_hash
from . import db from . import db
from .models import Relay, RelayList from .models import Relay
async def get_relays() -> RelayList: async def get_relays() -> List[Relay]:
row = await db.fetchall("SELECT * FROM nostrclient.relays") rows = await db.fetchall("SELECT * FROM nostrclient.relays")
return RelayList(__root__=row) return [Relay.from_row(r) for r in rows]
async def add_relay(relay: Relay) -> None: async def add_relay(relay: Relay) -> None:
await db.execute( await db.execute(
f""" """
INSERT INTO nostrclient.relays ( INSERT INTO nostrclient.relays (
id, id,
url, url,

View file

@ -3,7 +3,7 @@ async def m001_initial(db):
Initial nostrclient table. Initial nostrclient table.
""" """
await db.execute( await db.execute(
f""" """
CREATE TABLE nostrclient.relays ( CREATE TABLE nostrclient.relays (
id TEXT NOT NULL PRIMARY KEY, id TEXT NOT NULL PRIMARY KEY,
url TEXT NOT NULL, url TEXT NOT NULL,

View file

@ -1,9 +1,7 @@
from dataclasses import dataclass from sqlite3 import Row
from typing import Dict, List, Optional from typing import List, Optional
from fastapi import Request from pydantic import BaseModel
from fastapi.param_functions import Query
from pydantic import BaseModel, Field
from lnbits.helpers import urlsafe_short_hash from lnbits.helpers import urlsafe_short_hash
@ -14,7 +12,8 @@ class RelayStatus(BaseModel):
error_counter: Optional[int] = 0 error_counter: Optional[int] = 0
error_list: Optional[List] = [] error_list: Optional[List] = []
notice_list: Optional[List] = [] notice_list: Optional[List] = []
class Relay(BaseModel): class Relay(BaseModel):
id: Optional[str] = None id: Optional[str] = None
url: Optional[str] = None url: Optional[str] = None
@ -28,33 +27,9 @@ class Relay(BaseModel):
if not self.id: if not self.id:
self.id = urlsafe_short_hash() self.id = urlsafe_short_hash()
@classmethod
class RelayList(BaseModel): def from_row(cls, row: Row) -> "Relay":
__root__: List[Relay] return cls(**dict(row))
class Event(BaseModel):
content: str
pubkey: str
created_at: Optional[int]
kind: int
tags: Optional[List[List[str]]]
sig: str
class Filter(BaseModel):
ids: Optional[List[str]]
kinds: Optional[List[int]]
authors: Optional[List[str]]
since: Optional[int]
until: Optional[int]
e: Optional[List[str]] = Field(alias="#e")
p: Optional[List[str]] = Field(alias="#p")
limit: Optional[int]
class Filters(BaseModel):
__root__: List[Filter]
class TestMessage(BaseModel): class TestMessage(BaseModel):
@ -62,6 +37,7 @@ class TestMessage(BaseModel):
reciever_public_key: str reciever_public_key: str
message: str message: str
class TestMessageResponse(BaseModel): class TestMessageResponse(BaseModel):
private_key: str private_key: str
public_key: str public_key: str

View file

View file

@ -26,19 +26,22 @@ from enum import Enum
class Encoding(Enum): class Encoding(Enum):
"""Enumeration type to list the various supported encodings.""" """Enumeration type to list the various supported encodings."""
BECH32 = 1 BECH32 = 1
BECH32M = 2 BECH32M = 2
CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l" CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l"
BECH32M_CONST = 0x2bc830a3 BECH32M_CONST = 0x2BC830A3
def bech32_polymod(values): def bech32_polymod(values):
"""Internal function that computes the Bech32 checksum.""" """Internal function that computes the Bech32 checksum."""
generator = [0x3b6a57b2, 0x26508e6d, 0x1ea119fa, 0x3d4233dd, 0x2a1462b3] generator = [0x3B6A57B2, 0x26508E6D, 0x1EA119FA, 0x3D4233DD, 0x2A1462B3]
chk = 1 chk = 1
for value in values: for value in values:
top = chk >> 25 top = chk >> 25
chk = (chk & 0x1ffffff) << 5 ^ value chk = (chk & 0x1FFFFFF) << 5 ^ value
for i in range(5): for i in range(5):
chk ^= generator[i] if ((top >> i) & 1) else 0 chk ^= generator[i] if ((top >> i) & 1) else 0
return chk return chk
@ -58,6 +61,7 @@ def bech32_verify_checksum(hrp, data):
return Encoding.BECH32M return Encoding.BECH32M
return None return None
def bech32_create_checksum(hrp, data, spec): def bech32_create_checksum(hrp, data, spec):
"""Compute the checksum values given HRP and data.""" """Compute the checksum values given HRP and data."""
values = bech32_hrp_expand(hrp) + data values = bech32_hrp_expand(hrp) + data
@ -69,26 +73,29 @@ def bech32_create_checksum(hrp, data, spec):
def bech32_encode(hrp, data, spec): def bech32_encode(hrp, data, spec):
"""Compute a Bech32 string given HRP and data values.""" """Compute a Bech32 string given HRP and data values."""
combined = data + bech32_create_checksum(hrp, data, spec) combined = data + bech32_create_checksum(hrp, data, spec)
return hrp + '1' + ''.join([CHARSET[d] for d in combined]) return hrp + "1" + "".join([CHARSET[d] for d in combined])
def bech32_decode(bech): def bech32_decode(bech):
"""Validate a Bech32/Bech32m string, and determine HRP and data.""" """Validate a Bech32/Bech32m string, and determine HRP and data."""
if ((any(ord(x) < 33 or ord(x) > 126 for x in bech)) or if (any(ord(x) < 33 or ord(x) > 126 for x in bech)) or (
(bech.lower() != bech and bech.upper() != bech)): bech.lower() != bech and bech.upper() != bech
):
return (None, None, None) return (None, None, None)
bech = bech.lower() bech = bech.lower()
pos = bech.rfind('1') pos = bech.rfind("1")
if pos < 1 or pos + 7 > len(bech) or len(bech) > 90: if pos < 1 or pos + 7 > len(bech) or len(bech) > 90:
return (None, None, None) return (None, None, None)
if not all(x in CHARSET for x in bech[pos+1:]): if not all(x in CHARSET for x in bech[pos + 1 :]):
return (None, None, None) return (None, None, None)
hrp = bech[:pos] hrp = bech[:pos]
data = [CHARSET.find(x) for x in bech[pos+1:]] data = [CHARSET.find(x) for x in bech[pos + 1 :]]
spec = bech32_verify_checksum(hrp, data) spec = bech32_verify_checksum(hrp, data)
if spec is None: if spec is None:
return (None, None, None) return (None, None, None)
return (hrp, data[:-6], spec) return (hrp, data[:-6], spec)
def convertbits(data, frombits, tobits, pad=True): def convertbits(data, frombits, tobits, pad=True):
"""General power-of-2 base conversion.""" """General power-of-2 base conversion."""
acc = 0 acc = 0
@ -124,7 +131,12 @@ def decode(hrp, addr):
return (None, None) return (None, None)
if data[0] == 0 and len(decoded) != 20 and len(decoded) != 32: if data[0] == 0 and len(decoded) != 20 and len(decoded) != 32:
return (None, None) return (None, None)
if data[0] == 0 and spec != Encoding.BECH32 or data[0] != 0 and spec != Encoding.BECH32M: if (
data[0] == 0
and spec != Encoding.BECH32
or data[0] != 0
and spec != Encoding.BECH32M
):
return (None, None) return (None, None)
return (data[0], decoded) return (data[0], decoded)

View file

@ -1,25 +1,36 @@
import asyncio import asyncio
from typing import List
from loguru import logger
from ..relay_manager import RelayManager from ..relay_manager import RelayManager
class NostrClient: class NostrClient:
relays = [ ]
relay_manager = RelayManager() relay_manager = RelayManager()
def __init__(self, relays: List[str] = [], connect=True): def __init__(self):
if len(relays): self.running = True
self.relays = relays
if connect:
self.connect()
async def connect(self): def connect(self, relays):
for relay in self.relays: for relay in relays:
self.relay_manager.add_relay(relay) try:
self.relay_manager.add_relay(relay)
except Exception as e:
logger.debug(e)
self.running = True
def reconnect(self, relays):
self.relay_manager.remove_relays()
self.connect(relays)
def close(self): def close(self):
self.relay_manager.close_connections() try:
self.relay_manager.close_all_subscriptions()
self.relay_manager.close_connections()
self.running = False
except Exception as e:
logger.error(e)
async def subscribe( async def subscribe(
self, self,
@ -27,18 +38,36 @@ class NostrClient:
callback_notices_func=None, callback_notices_func=None,
callback_eosenotices_func=None, callback_eosenotices_func=None,
): ):
while True: while self.running:
self._check_events(callback_events_func)
self._check_notices(callback_notices_func)
self._check_eos_notices(callback_eosenotices_func)
await asyncio.sleep(0.2)
def _check_events(self, callback_events_func=None):
try:
while self.relay_manager.message_pool.has_events(): while self.relay_manager.message_pool.has_events():
event_msg = self.relay_manager.message_pool.get_event() event_msg = self.relay_manager.message_pool.get_event()
if callback_events_func: if callback_events_func:
callback_events_func(event_msg) callback_events_func(event_msg)
except Exception as e:
logger.debug(e)
def _check_notices(self, callback_notices_func=None):
try:
while self.relay_manager.message_pool.has_notices(): while self.relay_manager.message_pool.has_notices():
event_msg = self.relay_manager.message_pool.get_notice() event_msg = self.relay_manager.message_pool.get_notice()
if callback_notices_func: if callback_notices_func:
callback_notices_func(event_msg) callback_notices_func(event_msg)
except Exception as e:
logger.debug(e)
def _check_eos_notices(self, callback_eosenotices_func=None):
try:
while self.relay_manager.message_pool.has_eose_notices(): while self.relay_manager.message_pool.has_eose_notices():
event_msg = self.relay_manager.message_pool.get_eose_notice() event_msg = self.relay_manager.message_pool.get_eose_notice()
if callback_eosenotices_func: if callback_eosenotices_func:
callback_eosenotices_func(event_msg) callback_eosenotices_func(event_msg)
except Exception as e:
await asyncio.sleep(0.5) logger.debug(e)

View file

@ -1,32 +0,0 @@
import time
from dataclasses import dataclass
@dataclass
class Delegation:
delegator_pubkey: str
delegatee_pubkey: str
event_kind: int
duration_secs: int = 30*24*60 # default to 30 days
signature: str = None # set in PrivateKey.sign_delegation
@property
def expires(self) -> int:
return int(time.time()) + self.duration_secs
@property
def conditions(self) -> str:
return f"kind={self.event_kind}&created_at<{self.expires}"
@property
def delegation_token(self) -> str:
return f"nostr:delegation:{self.delegatee_pubkey}:{self.conditions}"
def get_tag(self) -> list[str]:
""" Called by Event """
return [
"delegation",
self.delegator_pubkey,
self.conditions,
self.signature,
]

View file

@ -122,6 +122,7 @@ class EncryptedDirectMessage(Event):
def id(self) -> str: def id(self) -> str:
if self.content is None: if self.content is None:
raise Exception( raise Exception(
"EncryptedDirectMessage `id` is undefined until its message is encrypted and stored in the `content` field" "EncryptedDirectMessage `id` is undefined until its"
+ " message is encrypted and stored in the `content` field"
) )
return super().id return super().id

View file

@ -1,134 +0,0 @@
from collections import UserList
from typing import List
from .event import Event, EventKind
class Filter:
"""
NIP-01 filtering.
Explicitly supports "#e" and "#p" tag filters via `event_refs` and `pubkey_refs`.
Arbitrary NIP-12 single-letter tag filters are also supported via `add_arbitrary_tag`.
If a particular single-letter tag gains prominence, explicit support should be
added. For example:
# arbitrary tag
filter.add_arbitrary_tag('t', [hashtags])
# promoted to explicit support
Filter(hashtag_refs=[hashtags])
"""
def __init__(
self,
event_ids: List[str] = None,
kinds: List[EventKind] = None,
authors: List[str] = None,
since: int = None,
until: int = None,
event_refs: List[
str
] = None, # the "#e" attr; list of event ids referenced in an "e" tag
pubkey_refs: List[
str
] = None, # The "#p" attr; list of pubkeys referenced in a "p" tag
limit: int = None,
) -> None:
self.event_ids = event_ids
self.kinds = kinds
self.authors = authors
self.since = since
self.until = until
self.event_refs = event_refs
self.pubkey_refs = pubkey_refs
self.limit = limit
self.tags = {}
if self.event_refs:
self.add_arbitrary_tag("e", self.event_refs)
if self.pubkey_refs:
self.add_arbitrary_tag("p", self.pubkey_refs)
def add_arbitrary_tag(self, tag: str, values: list):
"""
Filter on any arbitrary tag with explicit handling for NIP-01 and NIP-12
single-letter tags.
"""
# NIP-01 'e' and 'p' tags and any NIP-12 single-letter tags must be prefixed with "#"
tag_key = tag if len(tag) > 1 else f"#{tag}"
self.tags[tag_key] = values
def matches(self, event: Event) -> bool:
if self.event_ids is not None and event.id not in self.event_ids:
return False
if self.kinds is not None and event.kind not in self.kinds:
return False
if self.authors is not None and event.public_key not in self.authors:
return False
if self.since is not None and event.created_at < self.since:
return False
if self.until is not None and event.created_at > self.until:
return False
if (self.event_refs is not None or self.pubkey_refs is not None) and len(
event.tags
) == 0:
return False
if self.tags:
e_tag_identifiers = set([e_tag[0] for e_tag in event.tags])
for f_tag, f_tag_values in self.tags.items():
# Omit any NIP-01 or NIP-12 "#" chars on single-letter tags
f_tag = f_tag.replace("#", "")
if f_tag not in e_tag_identifiers:
# Event is missing a tag type that we're looking for
return False
# Multiple values within f_tag_values are treated as OR search; an Event
# needs to match only one.
# Note: an Event could have multiple entries of the same tag type
# (e.g. a reply to multiple people) so we have to check all of them.
match_found = False
for e_tag in event.tags:
if e_tag[0] == f_tag and e_tag[1] in f_tag_values:
match_found = True
break
if not match_found:
return False
return True
def to_json_object(self) -> dict:
res = {}
if self.event_ids is not None:
res["ids"] = self.event_ids
if self.kinds is not None:
res["kinds"] = self.kinds
if self.authors is not None:
res["authors"] = self.authors
if self.since is not None:
res["since"] = self.since
if self.until is not None:
res["until"] = self.until
if self.limit is not None:
res["limit"] = self.limit
if self.tags:
res.update(self.tags)
return res
class Filters(UserList):
def __init__(self, initlist: "list[Filter]" = []) -> None:
super().__init__(initlist)
self.data: "list[Filter]"
def match(self, event: Event):
for filter in self.data:
if filter.matches(event):
return True
return False
def to_json_array(self) -> list:
return [filter.to_json_object() for filter in self.data]

View file

@ -1,6 +1,5 @@
import base64 import base64
import secrets import secrets
from hashlib import sha256
import secp256k1 import secp256k1
from cffi import FFI from cffi import FFI
@ -8,7 +7,6 @@ from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from . import bech32 from . import bech32
from .delegation import Delegation
from .event import EncryptedDirectMessage, Event, EventKind from .event import EncryptedDirectMessage, Event, EventKind
@ -37,7 +35,7 @@ class PublicKey:
class PrivateKey: class PrivateKey:
def __init__(self, raw_secret: bytes = None) -> None: def __init__(self, raw_secret: bytes = None) -> None:
if not raw_secret is None: if raw_secret is not None:
self.raw_secret = raw_secret self.raw_secret = raw_secret
else: else:
self.raw_secret = secrets.token_bytes(32) self.raw_secret = secrets.token_bytes(32)
@ -79,7 +77,10 @@ class PrivateKey:
encryptor = cipher.encryptor() encryptor = cipher.encryptor()
encrypted_message = encryptor.update(padded_data) + encryptor.finalize() encrypted_message = encryptor.update(padded_data) + encryptor.finalize()
return f"{base64.b64encode(encrypted_message).decode()}?iv={base64.b64encode(iv).decode()}" return (
f"{base64.b64encode(encrypted_message).decode()}"
+ f"?iv={base64.b64encode(iv).decode()}"
)
def encrypt_dm(self, dm: EncryptedDirectMessage) -> None: def encrypt_dm(self, dm: EncryptedDirectMessage) -> None:
dm.content = self.encrypt_message( dm.content = self.encrypt_message(
@ -116,11 +117,6 @@ class PrivateKey:
event.public_key = self.public_key.hex() event.public_key = self.public_key.hex()
event.signature = self.sign_message_hash(bytes.fromhex(event.id)) event.signature = self.sign_message_hash(bytes.fromhex(event.id))
def sign_delegation(self, delegation: Delegation) -> None:
delegation.signature = self.sign_message_hash(
sha256(delegation.delegation_token.encode()).digest()
)
def __eq__(self, other): def __eq__(self, other):
return self.raw_secret == other.raw_secret return self.raw_secret == other.raw_secret

View file

@ -2,13 +2,15 @@ import json
from queue import Queue from queue import Queue
from threading import Lock from threading import Lock
from .event import Event
from .message_type import RelayMessageType from .message_type import RelayMessageType
class EventMessage: class EventMessage:
def __init__(self, event: Event, subscription_id: str, url: str) -> None: def __init__(
self, event: str, event_id: str, subscription_id: str, url: str
) -> None:
self.event = event self.event = event
self.event_id = event_id
self.subscription_id = subscription_id self.subscription_id = subscription_id
self.url = url self.url = url
@ -59,18 +61,16 @@ class MessagePool:
message_type = message_json[0] message_type = message_json[0]
if message_type == RelayMessageType.EVENT: if message_type == RelayMessageType.EVENT:
subscription_id = message_json[1] subscription_id = message_json[1]
e = message_json[2] event = message_json[2]
event = Event( if "id" not in event:
e["content"], return
e["pubkey"], event_id = event["id"]
e["created_at"],
e["kind"],
e["tags"],
e["sig"],
)
with self.lock: with self.lock:
if not f"{subscription_id}_{event.id}" in self._unique_events: if f"{subscription_id}_{event_id}" not in self._unique_events:
self._accept_event(EventMessage(event, subscription_id, url)) self._accept_event(
EventMessage(json.dumps(event), event_id, subscription_id, url)
)
elif message_type == RelayMessageType.NOTICE: elif message_type == RelayMessageType.NOTICE:
self.notices.put(NoticeMessage(message_json[1], url)) self.notices.put(NoticeMessage(message_json[1], url))
elif message_type == RelayMessageType.END_OF_STORED_EVENTS: elif message_type == RelayMessageType.END_OF_STORED_EVENTS:
@ -78,10 +78,12 @@ class MessagePool:
def _accept_event(self, event_message: EventMessage): def _accept_event(self, event_message: EventMessage):
""" """
Event uniqueness is considered per `subscription_id`. Event uniqueness is considered per `subscription_id`. The `subscription_id` is
The `subscription_id` is rewritten to be unique and it is the same accross relays. rewritten to be unique and it is the same accross relays. The same event can
The same event can come from different subscriptions (from the same client or from different ones). come from different subscriptions (from the same client or from different ones).
Clients that have joined later should receive older events. Clients that have joined later should receive older events.
""" """
self.events.put(event_message) self.events.put(event_message)
self._unique_events.add(f"{event_message.subscription_id}_{event_message.event.id}") self._unique_events.add(
f"{event_message.subscription_id}_{event_message.event_id}"
)

View file

@ -2,43 +2,23 @@ import asyncio
import json import json
import time import time
from queue import Queue from queue import Queue
from threading import Lock
from typing import List from typing import List
from loguru import logger from loguru import logger
from websocket import WebSocketApp from websocket import WebSocketApp
from .event import Event
from .filter import Filters
from .message_pool import MessagePool from .message_pool import MessagePool
from .message_type import RelayMessageType
from .subscription import Subscription from .subscription import Subscription
class RelayPolicy:
def __init__(self, should_read: bool = True, should_write: bool = True) -> None:
self.should_read = should_read
self.should_write = should_write
def to_json_object(self) -> dict[str, bool]:
return {"read": self.should_read, "write": self.should_write}
class Relay: class Relay:
def __init__( def __init__(self, url: str, message_pool: MessagePool) -> None:
self,
url: str,
policy: RelayPolicy,
message_pool: MessagePool,
subscriptions: dict[str, Subscription] = {},
) -> None:
self.url = url self.url = url
self.policy = policy
self.message_pool = message_pool self.message_pool = message_pool
self.subscriptions = subscriptions
self.connected: bool = False self.connected: bool = False
self.reconnect: bool = True self.reconnect: bool = True
self.shutdown: bool = False self.shutdown: bool = False
self.error_counter: int = 0 self.error_counter: int = 0
self.error_threshold: int = 100 self.error_threshold: int = 100
self.error_list: List[str] = [] self.error_list: List[str] = []
@ -47,12 +27,10 @@ class Relay:
self.num_received_events: int = 0 self.num_received_events: int = 0
self.num_sent_events: int = 0 self.num_sent_events: int = 0
self.num_subscriptions: int = 0 self.num_subscriptions: int = 0
self.ssl_options: dict = {}
self.proxy: dict = {}
self.lock = Lock()
self.queue = Queue() self.queue = Queue()
def connect(self, ssl_options: dict = None, proxy: dict = None): def connect(self):
self.ws = WebSocketApp( self.ws = WebSocketApp(
self.url, self.url,
on_open=self._on_open, on_open=self._on_open,
@ -62,19 +40,14 @@ class Relay:
on_ping=self._on_ping, on_ping=self._on_ping,
on_pong=self._on_pong, on_pong=self._on_pong,
) )
self.ssl_options = ssl_options
self.proxy = proxy
if not self.connected: if not self.connected:
self.ws.run_forever( self.ws.run_forever(ping_interval=10)
sslopt=ssl_options,
http_proxy_host=None if proxy is None else proxy.get("host"),
http_proxy_port=None if proxy is None else proxy.get("port"),
proxy_type=None if proxy is None else proxy.get("type"),
ping_interval=5,
)
def close(self): def close(self):
self.ws.close() try:
self.ws.close()
except Exception as e:
logger.warning(f"[Relay: {self.url}] Failed to close websocket: {e}")
self.connected = False self.connected = False
self.shutdown = True self.shutdown = True
@ -90,10 +63,9 @@ class Relay:
def publish(self, message: str): def publish(self, message: str):
self.queue.put(message) self.queue.put(message)
def publish_subscriptions(self): def publish_subscriptions(self, subscriptions: List[Subscription] = []):
for _, subscription in self.subscriptions.items(): for s in subscriptions:
s = subscription.to_json_object() json_str = json.dumps(["REQ", s.id] + s.filters)
json_str = json.dumps(["REQ", s["id"], s["filters"][0]])
self.publish(json_str) self.publish(json_str)
async def queue_worker(self): async def queue_worker(self):
@ -103,55 +75,44 @@ class Relay:
message = self.queue.get(timeout=1) message = self.queue.get(timeout=1)
self.num_sent_events += 1 self.num_sent_events += 1
self.ws.send(message) self.ws.send(message)
except: except Exception as _:
pass pass
else: else:
await asyncio.sleep(1) await asyncio.sleep(1)
if self.shutdown:
logger.warning(f"Closing queue worker for '{self.url}'.")
break
def add_subscription(self, id, filters: Filters): if self.shutdown:
with self.lock: logger.warning(f"[Relay: {self.url}] Closing queue worker.")
self.subscriptions[id] = Subscription(id, filters) return
def close_subscription(self, id: str) -> None: def close_subscription(self, id: str) -> None:
with self.lock: try:
self.subscriptions.pop(id)
self.publish(json.dumps(["CLOSE", id])) self.publish(json.dumps(["CLOSE", id]))
except Exception as e:
def to_json_object(self) -> dict: logger.debug(f"[Relay: {self.url}] Failed to close subscription: {e}")
return {
"url": self.url,
"policy": self.policy.to_json_object(),
"subscriptions": [
subscription.to_json_object()
for subscription in self.subscriptions.values()
],
}
def add_notice(self, notice: str): def add_notice(self, notice: str):
self.notice_list = ([notice] + self.notice_list)[:20] self.notice_list = [notice] + self.notice_list
def _on_open(self, _): def _on_open(self, _):
logger.info(f"Connected to relay: '{self.url}'.") logger.info(f"[Relay: {self.url}] Connected.")
self.connected = True self.connected = True
self.shutdown = False
def _on_close(self, _, status_code, message): def _on_close(self, _, status_code, message):
logger.warning(f"Connection to relay {self.url} closed. Status: '{status_code}'. Message: '{message}'.") logger.warning(
f"[Relay: {self.url}] Connection closed."
+ f" Status: '{status_code}'. Message: '{message}'."
)
self.close() self.close()
def _on_message(self, _, message: str): def _on_message(self, _, message: str):
if self._is_valid_message(message): self.num_received_events += 1
self.num_received_events += 1 self.message_pool.add_message(message, self.url)
self.message_pool.add_message(message, self.url)
def _on_error(self, _, error): def _on_error(self, _, error):
logger.warning(f"Relay error: '{str(error)}'") logger.warning(f"[Relay: {self.url}] Error: '{str(error)}'")
self._append_error_message(str(error)) self._append_error_message(str(error))
self.connected = False self.close()
self.error_counter += 1
def _on_ping(self, *_): def _on_ping(self, *_):
return return
@ -159,65 +120,7 @@ class Relay:
def _on_pong(self, *_): def _on_pong(self, *_):
return return
def _is_valid_message(self, message: str) -> bool:
message = message.strip("\n")
if not message or message[0] != "[" or message[-1] != "]":
return False
message_json = json.loads(message)
message_type = message_json[0]
if not RelayMessageType.is_valid(message_type):
return False
if message_type == RelayMessageType.EVENT:
return self._is_valid_event_message(message_json)
if message_type == RelayMessageType.COMMAND_RESULT:
return self._is_valid_command_result_message(message, message_json)
return True
def _is_valid_event_message(self, message_json):
if not len(message_json) == 3:
return False
subscription_id = message_json[1]
with self.lock:
if subscription_id not in self.subscriptions:
return False
e = message_json[2]
event = Event(
e["content"],
e["pubkey"],
e["created_at"],
e["kind"],
e["tags"],
e["sig"],
)
if not event.verify():
return False
with self.lock:
subscription = self.subscriptions[subscription_id]
if subscription.filters and not subscription.filters.match(event):
return False
return True
def _is_valid_command_result_message(self, message, message_json):
if not len(message_json) < 3:
return False
if message_json[2] != True:
logger.warning(f"Relay '{self.url}' negative command result: '{message}'")
self._append_error_message(message)
return False
return True
def _append_error_message(self, message): def _append_error_message(self, message):
self.error_list = ([message] + self.error_list)[:20] self.error_counter += 1
self.last_error_date = int(time.time()) self.error_list = [message] + self.error_list
self.last_error_date = int(time.time())

View file

@ -1,21 +1,15 @@
import asyncio import asyncio
import ssl
import threading import threading
import time import time
from typing import List
from loguru import logger from loguru import logger
from .filter import Filters
from .message_pool import MessagePool, NoticeMessage from .message_pool import MessagePool, NoticeMessage
from .relay import Relay, RelayPolicy from .relay import Relay
from .subscription import Subscription from .subscription import Subscription
class RelayException(Exception):
pass
class RelayManager: class RelayManager:
def __init__(self) -> None: def __init__(self) -> None:
self.relays: dict[str, Relay] = {} self.relays: dict[str, Relay] = {}
@ -25,72 +19,97 @@ class RelayManager:
self._cached_subscriptions: dict[str, Subscription] = {} self._cached_subscriptions: dict[str, Subscription] = {}
self._subscriptions_lock = threading.Lock() self._subscriptions_lock = threading.Lock()
def add_relay(self, url: str, read: bool = True, write: bool = True) -> Relay: def add_relay(self, url: str) -> Relay:
if url in list(self.relays.keys()): if url in list(self.relays.keys()):
return logger.debug(f"Relay '{url}' already present.")
return self.relays[url]
with self._subscriptions_lock:
subscriptions = self._cached_subscriptions.copy()
policy = RelayPolicy(read, write) relay = Relay(url, self.message_pool)
relay = Relay(url, policy, self.message_pool, subscriptions)
self.relays[url] = relay self.relays[url] = relay
self._open_connection( self._open_connection(relay)
relay,
{"cert_reqs": ssl.CERT_NONE}
) # NOTE: This disables ssl certificate verification
relay.publish_subscriptions() relay.publish_subscriptions(list(self._cached_subscriptions.values()))
return relay return relay
def remove_relay(self, url: str): def remove_relay(self, url: str):
self.relays[url].close() try:
self.relays.pop(url) self.relays[url].close()
self.threads[url].join(timeout=5) except Exception as e:
self.threads.pop(url) logger.debug(e)
self.queue_threads[url].join(timeout=5)
self.queue_threads.pop(url)
def add_subscription(self, id: str, filters: Filters): if url in self.relays:
self.relays.pop(url)
try:
self.threads[url].join(timeout=5)
except Exception as e:
logger.debug(e)
if url in self.threads:
self.threads.pop(url)
try:
self.queue_threads[url].join(timeout=5)
except Exception as e:
logger.debug(e)
if url in self.queue_threads:
self.queue_threads.pop(url)
def remove_relays(self):
relay_urls = list(self.relays.keys())
for url in relay_urls:
self.remove_relay(url)
def add_subscription(self, id: str, filters: List[str]):
s = Subscription(id, filters)
with self._subscriptions_lock: with self._subscriptions_lock:
self._cached_subscriptions[id] = Subscription(id, filters) self._cached_subscriptions[id] = s
for relay in self.relays.values(): for relay in self.relays.values():
relay.add_subscription(id, filters) relay.publish_subscriptions([s])
def close_subscription(self, id: str): def close_subscription(self, id: str):
with self._subscriptions_lock: try:
self._cached_subscriptions.pop(id) with self._subscriptions_lock:
if id in self._cached_subscriptions:
self._cached_subscriptions.pop(id)
for relay in self.relays.values(): for relay in self.relays.values():
relay.close_subscription(id) relay.close_subscription(id)
except Exception as e:
logger.debug(e)
def close_subscriptions(self, subscriptions: List[str]):
for id in subscriptions:
self.close_subscription(id)
def close_all_subscriptions(self):
all_subscriptions = list(self._cached_subscriptions.keys())
self.close_subscriptions(all_subscriptions)
def check_and_restart_relays(self): def check_and_restart_relays(self):
stopped_relays = [r for r in self.relays.values() if r.shutdown] stopped_relays = [r for r in self.relays.values() if r.shutdown]
for relay in stopped_relays: for relay in stopped_relays:
self._restart_relay(relay) self._restart_relay(relay)
def close_connections(self): def close_connections(self):
for relay in self.relays.values(): for relay in self.relays.values():
relay.close() relay.close()
def publish_message(self, message: str): def publish_message(self, message: str):
for relay in self.relays.values(): for relay in self.relays.values():
if relay.policy.should_write: relay.publish(message)
relay.publish(message)
def handle_notice(self, notice: NoticeMessage): def handle_notice(self, notice: NoticeMessage):
relay = next((r for r in self.relays.values() if r.url == notice.url)) relay = next((r for r in self.relays.values() if r.url == notice.url))
if relay: if relay:
relay.add_notice(notice.content) relay.add_notice(notice.content)
def _open_connection(self, relay: Relay, ssl_options: dict = None, proxy: dict = None): def _open_connection(self, relay: Relay):
self.threads[relay.url] = threading.Thread( self.threads[relay.url] = threading.Thread(
target=relay.connect, target=relay.connect,
args=(ssl_options, proxy),
name=f"{relay.url}-thread", name=f"{relay.url}-thread",
daemon=True, daemon=True,
) )
@ -98,7 +117,7 @@ class RelayManager:
def wrap_async_queue_worker(): def wrap_async_queue_worker():
asyncio.run(relay.queue_worker()) asyncio.run(relay.queue_worker())
self.queue_threads[relay.url] = threading.Thread( self.queue_threads[relay.url] = threading.Thread(
target=wrap_async_queue_worker, target=wrap_async_queue_worker,
name=f"{relay.url}-queue", name=f"{relay.url}-queue",
@ -108,14 +127,16 @@ class RelayManager:
def _restart_relay(self, relay: Relay): def _restart_relay(self, relay: Relay):
time_since_last_error = time.time() - relay.last_error_date time_since_last_error = time.time() - relay.last_error_date
min_wait_time = min(60 * relay.error_counter, 60 * 60 * 24) # try at least once a day min_wait_time = min(
60 * relay.error_counter, 60 * 60
) # try at least once an hour
if time_since_last_error < min_wait_time: if time_since_last_error < min_wait_time:
return return
logger.info(f"Restarting connection to relay '{relay.url}'") logger.info(f"Restarting connection to relay '{relay.url}'")
self.remove_relay(relay.url) self.remove_relay(relay.url)
new_relay = self.add_relay(relay.url) new_relay = self.add_relay(relay.url)
new_relay.error_counter = relay.error_counter new_relay.error_counter = relay.error_counter
new_relay.error_list = relay.error_list new_relay.error_list = relay.error_list

View file

@ -1,13 +1,7 @@
from .filter import Filters from typing import List
class Subscription: class Subscription:
def __init__(self, id: str, filters: Filters=None) -> None: def __init__(self, id: str, filters: List[str] = None) -> None:
self.id = id self.id = id
self.filters = filters self.filters = filters
def to_json_object(self):
return {
"id": self.id,
"filters": self.filters.to_json_array()
}

178
router.py
View file

@ -1,42 +1,61 @@
import asyncio import asyncio
import json import json
from typing import List, Union from typing import Dict, List
from fastapi import WebSocketDisconnect from fastapi import WebSocket, WebSocketDisconnect
from loguru import logger from loguru import logger
from lnbits.helpers import urlsafe_short_hash from lnbits.helpers import urlsafe_short_hash
from . import nostr from . import nostr_client
from .models import Event, Filter from .nostr.message_pool import EndOfStoredEventsMessage, EventMessage, NoticeMessage
from .nostr.filter import Filter as NostrFilter
from .nostr.filter import Filters as NostrFilters
from .nostr.message_pool import EndOfStoredEventsMessage, NoticeMessage
class NostrRouter: class NostrRouter:
received_subscription_events: dict[str, List[EventMessage]] = {}
received_subscription_events: dict[str, list[Event]] = {}
received_subscription_notices: list[NoticeMessage] = [] received_subscription_notices: list[NoticeMessage] = []
received_subscription_eosenotices: dict[str, EndOfStoredEventsMessage] = {} received_subscription_eosenotices: dict[str, EndOfStoredEventsMessage] = {}
def __init__(self, websocket): def __init__(self, websocket: WebSocket):
self.subscriptions: List[str] = []
self.connected: bool = True self.connected: bool = True
self.websocket = websocket self.websocket: WebSocket = websocket
self.tasks: List[asyncio.Task] = [] self.tasks: List[asyncio.Task] = []
self.original_subscription_ids = {} self.original_subscription_ids: Dict[str, str] = {}
async def client_to_nostr(self): @property
"""Receives requests / data from the client and forwards it to relays. If the def subscriptions(self) -> List[str]:
request was a subscription/filter, registers it with the nostr client lib. return list(self.original_subscription_ids.keys())
Remembers the subscription id so we can send back responses from the relay to this
client in `nostr_to_client`""" def start(self):
while True: self.connected = True
self.tasks.append(asyncio.create_task(self._client_to_nostr()))
self.tasks.append(asyncio.create_task(self._nostr_to_client()))
async def stop(self):
nostr_client.relay_manager.close_subscriptions(self.subscriptions)
self.connected = False
for t in self.tasks:
try:
t.cancel()
except Exception as _:
pass
try:
await self.websocket.close()
except Exception as _:
pass
async def _client_to_nostr(self):
"""
Receives requests / data from the client and forwards it to relays.
"""
while self.connected:
try: try:
json_str = await self.websocket.receive_text() json_str = await self.websocket.receive_text()
except WebSocketDisconnect: except WebSocketDisconnect as e:
self.connected = False logger.debug(e)
await self.stop()
break break
try: try:
@ -44,15 +63,9 @@ class NostrRouter:
except Exception as e: except Exception as e:
logger.debug(f"Failed to handle client message: '{str(e)}'.") logger.debug(f"Failed to handle client message: '{str(e)}'.")
async def _nostr_to_client(self):
async def nostr_to_client(self): """Sends responses from relays back to the client."""
"""Sends responses from relays back to the client. Polls the subscriptions of this client while self.connected:
stored in `my_subscriptions`. Then gets all responses for this subscription id from `received_subscription_events` which
is filled in tasks.py. Takes one response after the other and relays it back to the client. Reconstructs
the reponse manually because the nostr client lib we're using can't do it. Reconstructs the original subscription id
that we had previously rewritten in order to avoid collisions when multiple clients use the same id.
"""
while True and self.connected:
try: try:
await self._handle_subscriptions() await self._handle_subscriptions()
self._handle_notices() self._handle_notices()
@ -61,24 +74,6 @@ class NostrRouter:
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
async def start(self):
self.tasks.append(asyncio.create_task(self.client_to_nostr()))
self.tasks.append(asyncio.create_task(self.nostr_to_client()))
async def stop(self):
for t in self.tasks:
try:
t.cancel()
except:
pass
for s in self.subscriptions:
try:
nostr.client.relay_manager.close_subscription(s)
except:
pass
self.connected = False
async def _handle_subscriptions(self): async def _handle_subscriptions(self):
for s in self.subscriptions: for s in self.subscriptions:
if s in NostrRouter.received_subscription_events: if s in NostrRouter.received_subscription_events:
@ -86,8 +81,6 @@ class NostrRouter:
if s in NostrRouter.received_subscription_eosenotices: if s in NostrRouter.received_subscription_eosenotices:
await self._handle_received_subscription_eosenotices(s) await self._handle_received_subscription_eosenotices(s)
async def _handle_received_subscription_eosenotices(self, s): async def _handle_received_subscription_eosenotices(self, s):
try: try:
if s not in self.original_subscription_ids: if s not in self.original_subscription_ids:
@ -95,7 +88,7 @@ class NostrRouter:
s_original = self.original_subscription_ids[s] s_original = self.original_subscription_ids[s]
event_to_forward = ["EOSE", s_original] event_to_forward = ["EOSE", s_original]
del NostrRouter.received_subscription_eosenotices[s] del NostrRouter.received_subscription_eosenotices[s]
await self.websocket.send_text(json.dumps(event_to_forward)) await self.websocket.send_text(json.dumps(event_to_forward))
except Exception as e: except Exception as e:
logger.debug(e) logger.debug(e)
@ -104,97 +97,62 @@ class NostrRouter:
try: try:
if s not in NostrRouter.received_subscription_events: if s not in NostrRouter.received_subscription_events:
return return
while len(NostrRouter.received_subscription_events[s]): while len(NostrRouter.received_subscription_events[s]):
my_event = NostrRouter.received_subscription_events[s].pop(0) event_message = NostrRouter.received_subscription_events[s].pop(0)
# event.to_message() does not include the subscription ID, we have to add it manually event_json = event_message.event
event_json = {
"id": my_event.id,
"pubkey": my_event.public_key,
"created_at": my_event.created_at,
"kind": my_event.kind,
"tags": my_event.tags,
"content": my_event.content,
"sig": my_event.signature,
}
# this reconstructs the original response from the relay # this reconstructs the original response from the relay
# reconstruct original subscription id # reconstruct original subscription id
s_original = self.original_subscription_ids[s] s_original = self.original_subscription_ids[s]
event_to_forward = ["EVENT", s_original, event_json] event_to_forward = f"""["EVENT", "{s_original}", {event_json}]"""
await self.websocket.send_text(json.dumps(event_to_forward)) await self.websocket.send_text(event_to_forward)
except Exception as e: except Exception as e:
logger.debug(e) logger.debug(e) # there are 2900 errors here
def _handle_notices(self): def _handle_notices(self):
while len(NostrRouter.received_subscription_notices): while len(NostrRouter.received_subscription_notices):
my_event = NostrRouter.received_subscription_notices.pop(0) my_event = NostrRouter.received_subscription_notices.pop(0)
# note: we don't send it to the user because we don't know who should receive it logger.info(f"[Relay '{my_event.url}'] Notice: '{my_event.content}']")
logger.info(f"Relay ('{my_event.url}') notice: '{my_event.content}']") # Note: we don't send it to the user because
nostr.client.relay_manager.handle_notice(my_event) # we don't know who should receive it
nostr_client.relay_manager.handle_notice(my_event)
def _marshall_nostr_filters(self, data: Union[dict, list]):
filters = data if isinstance(data, list) else [data]
filters = [Filter.parse_obj(f) for f in filters]
filter_list: list[NostrFilter] = []
for filter in filters:
filter_list.append(
NostrFilter(
event_ids=filter.ids, # type: ignore
kinds=filter.kinds, # type: ignore
authors=filter.authors, # type: ignore
since=filter.since, # type: ignore
until=filter.until, # type: ignore
event_refs=filter.e, # type: ignore
pubkey_refs=filter.p, # type: ignore
limit=filter.limit, # type: ignore
)
)
return NostrFilters(filter_list)
async def _handle_client_to_nostr(self, json_str): async def _handle_client_to_nostr(self, json_str):
"""Parses a (string) request from a client. If it is a subscription (REQ) or a CLOSE, it will
register the subscription in the nostr client library that we're using so we can
receive the callbacks on it later. Will rewrite the subscription id since we expect
multiple clients to use the router and want to avoid subscription id collisions
"""
json_data = json.loads(json_str) json_data = json.loads(json_str)
assert len(json_data) assert len(json_data), "Bad JSON array"
if json_data[0] == "REQ": if json_data[0] == "REQ":
self._handle_client_req(json_data) self._handle_client_req(json_data)
return return
if json_data[0] == "CLOSE": if json_data[0] == "CLOSE":
self._handle_client_close(json_data[1]) self._handle_client_close(json_data[1])
return return
if json_data[0] == "EVENT": if json_data[0] == "EVENT":
nostr.client.relay_manager.publish_message(json_str) nostr_client.relay_manager.publish_message(json_str)
return return
def _handle_client_req(self, json_data): def _handle_client_req(self, json_data):
subscription_id = json_data[1] subscription_id = json_data[1]
subscription_id_rewritten = urlsafe_short_hash() subscription_id_rewritten = urlsafe_short_hash()
self.original_subscription_ids[subscription_id_rewritten] = subscription_id self.original_subscription_ids[subscription_id_rewritten] = subscription_id
fltr = json_data[2:] filters = json_data[2:]
filters = self._marshall_nostr_filters(fltr)
nostr.client.relay_manager.add_subscription( nostr_client.relay_manager.add_subscription(subscription_id_rewritten, filters)
subscription_id_rewritten, filters
)
request_rewritten = json.dumps([json_data[0], subscription_id_rewritten] + fltr)
self.subscriptions.append(subscription_id_rewritten)
nostr.client.relay_manager.publish_message(request_rewritten)
def _handle_client_close(self, subscription_id): def _handle_client_close(self, subscription_id):
subscription_id_rewritten = next((k for k, v in self.original_subscription_ids.items() if v == subscription_id), None) subscription_id_rewritten = next(
(
k
for k, v in self.original_subscription_ids.items()
if v == subscription_id
),
None,
)
if subscription_id_rewritten: if subscription_id_rewritten:
self.original_subscription_ids.pop(subscription_id_rewritten) self.original_subscription_ids.pop(subscription_id_rewritten)
nostr.client.relay_manager.close_subscription(subscription_id_rewritten) nostr_client.relay_manager.close_subscription(subscription_id_rewritten)
else: else:
logger.debug(f"Failed to unsubscribe from '{subscription_id}.'") logger.debug(f"Failed to unsubscribe from '{subscription_id}.'")

View file

@ -3,75 +3,69 @@ import threading
from loguru import logger from loguru import logger
from . import nostr from . import nostr_client
from .crud import get_relays from .crud import get_relays
from .nostr.message_pool import EndOfStoredEventsMessage, EventMessage, NoticeMessage from .nostr.message_pool import EndOfStoredEventsMessage, EventMessage, NoticeMessage
from .router import NostrRouter, nostr from .router import NostrRouter
async def init_relays(): async def init_relays():
# reinitialize the entire client
nostr.__init__()
# get relays from db # get relays from db
relays = await get_relays() relays = await get_relays()
# set relays and connect to them # set relays and connect to them
nostr.client.relays = list(set([r.url for r in relays.__root__ if r.url])) valid_relays = list(set([r.url for r in relays if r.url]))
await nostr.client.connect()
nostr_client.reconnect(valid_relays)
async def check_relays(): async def check_relays():
""" Check relays that have been disconnected """ """Check relays that have been disconnected"""
while True: while True:
try: try:
await asyncio.sleep(20) await asyncio.sleep(20)
nostr.client.relay_manager.check_and_restart_relays() nostr_client.relay_manager.check_and_restart_relays()
except Exception as e: except Exception as e:
logger.warning(f"Cannot restart relays: '{str(e)}'.") logger.warning(f"Cannot restart relays: '{str(e)}'.")
async def subscribe_events(): async def subscribe_events():
while not any([r.connected for r in nostr.client.relay_manager.relays.values()]): while not any([r.connected for r in nostr_client.relay_manager.relays.values()]):
await asyncio.sleep(2) await asyncio.sleep(2)
def callback_events(eventMessage: EventMessage): def callback_events(eventMessage: EventMessage):
if eventMessage.subscription_id in NostrRouter.received_subscription_events: sub_id = eventMessage.subscription_id
# do not add duplicate events (by event id) if sub_id not in NostrRouter.received_subscription_events:
if eventMessage.event.id in set( NostrRouter.received_subscription_events[sub_id] = [eventMessage]
[ return
e.id
for e in NostrRouter.received_subscription_events[eventMessage.subscription_id]
]
):
return
NostrRouter.received_subscription_events[eventMessage.subscription_id].append( # do not add duplicate events (by event id)
eventMessage.event ids = set(
) [e.event_id for e in NostrRouter.received_subscription_events[sub_id]]
else: )
NostrRouter.received_subscription_events[eventMessage.subscription_id] = [ if eventMessage.event_id in ids:
eventMessage.event return
]
return NostrRouter.received_subscription_events[sub_id].append(eventMessage)
def callback_notices(noticeMessage: NoticeMessage): def callback_notices(noticeMessage: NoticeMessage):
if noticeMessage not in NostrRouter.received_subscription_notices: if noticeMessage not in NostrRouter.received_subscription_notices:
NostrRouter.received_subscription_notices.append(noticeMessage) NostrRouter.received_subscription_notices.append(noticeMessage)
return
def callback_eose_notices(eventMessage: EndOfStoredEventsMessage): def callback_eose_notices(eventMessage: EndOfStoredEventsMessage):
if eventMessage.subscription_id not in NostrRouter.received_subscription_eosenotices: sub_id = eventMessage.subscription_id
NostrRouter.received_subscription_eosenotices[ if sub_id in NostrRouter.received_subscription_eosenotices:
eventMessage.subscription_id return
] = eventMessage
return NostrRouter.received_subscription_eosenotices[sub_id] = eventMessage
def wrap_async_subscribe(): def wrap_async_subscribe():
asyncio.run(nostr.client.subscribe( asyncio.run(
callback_events, nostr_client.subscribe(
callback_notices, callback_events,
callback_eose_notices, callback_notices,
)) callback_eose_notices,
)
)
t = threading.Thread( t = threading.Thread(
target=wrap_async_subscribe, target=wrap_async_subscribe,

View file

@ -6,13 +6,30 @@
<q-form @submit="addRelay"> <q-form @submit="addRelay">
<div class="row q-pa-md"> <div class="row q-pa-md">
<div class="col-9"> <div class="col-9">
<q-input outlined v-model="relayToAdd" dense filled label="Relay URL"></q-input> <q-input
outlined
v-model="relayToAdd"
dense
filled
label="Relay URL"
></q-input>
</div> </div>
<div class="col-3"> <div class="col-3">
<q-btn-dropdown
<q-btn-dropdown unelevated split color="primary" class="float-right" type="submit" label="Add Relay"> unelevated
<q-item v-for="relay in predefinedRelays" :key="relay" @click="addCustomRelay(relay)" clickable split
v-close-popup> color="primary"
class="float-right"
type="submit"
label="Add Relay X"
>
<q-item
v-for="relay in predefinedRelays"
:key="relay"
@click="addCustomRelay(relay)"
clickable
v-close-popup
>
<q-item-section> <q-item-section>
<q-item-label><span v-text="relay"></span></q-item-label> <q-item-label><span v-text="relay"></span></q-item-label>
</q-item-section> </q-item-section>
@ -29,18 +46,36 @@
<h5 class="text-subtitle1 q-my-none">Nostrclient</h5> <h5 class="text-subtitle1 q-my-none">Nostrclient</h5>
</div> </div>
<div class="col-auto"> <div class="col-auto">
<q-input borderless dense debounce="300" v-model="filter" placeholder="Search"> <q-input
borderless
dense
debounce="300"
v-model="filter"
placeholder="Search"
>
<template v-slot:append> <template v-slot:append>
<q-icon name="search"></q-icon> <q-icon name="search"></q-icon>
</template> </template>
</q-input> </q-input>
</div> </div>
</div> </div>
<q-table flat dense :data="nostrrelayLinks" row-key="id" :columns="relayTable.columns" <q-table
:pagination.sync="relayTable.pagination" :filter="filter"> flat
dense
:data="nostrrelayLinks"
row-key="id"
:columns="relayTable.columns"
:pagination.sync="relayTable.pagination"
:filter="filter"
>
<template v-slot:header="props"> <template v-slot:header="props">
<q-tr :props="props"> <q-tr :props="props">
<q-th v-for="col in props.cols" :key="col.name" :props="props" auto-width> <q-th
v-for="col in props.cols"
:key="col.name"
:props="props"
auto-width
>
<div v-if="col.name == 'id'"></div> <div v-if="col.name == 'id'"></div>
<div v-else>{{ col.label }}</div> <div v-else>{{ col.label }}</div>
</q-th> </q-th>
@ -49,29 +84,43 @@
<template v-slot:body="props"> <template v-slot:body="props">
<q-tr :props="props"> <q-tr :props="props">
<q-td v-for="col in props.cols" :key="col.name" :props="props" auto-width> <q-td
v-for="col in props.cols"
:key="col.name"
:props="props"
auto-width
>
<div v-if="col.name == 'connected'"> <div v-if="col.name == 'connected'">
<div v-if="col.value">🟢</div> <div v-if="col.value">🟢</div>
<div v-else> 🔴 </div> <div v-else>🔴</div>
</div> </div>
<div v-else-if="col.name == 'status'"> <div v-else-if="col.name == 'status'">
<div> <div>
⬆️ <span v-text="col.value.sentEvents"></span> ⬆️ <span v-text="col.value.sentEvents"></span> ⬇️
⬇️ <span v-text="col.value.receveidEvents"></span> <span v-text="col.value.receveidEvents"></span>
<span @click="showLogDataDialog(col.value.errorList)" class="cursor-pointer"> <span
⚠️ <span v-text="col.value.errorCount"> @click="showLogDataDialog(col.value.errorList)"
</span> class="cursor-pointer"
>
⚠️ <span v-text="col.value.errorCount"> </span>
</span> </span>
<span @click="showLogDataDialog(col.value.noticeList)" class="cursor-pointer float-right"> <span
@click="showLogDataDialog(col.value.noticeList)"
class="cursor-pointer float-right"
>
</span>
</span> </span>
</div> </div>
</div> </div>
<div v-else-if="col.name == 'delete'"> <div v-else-if="col.name == 'delete'">
<q-btn flat dense size="md" @click="showDeleteRelayDialog(props.row.url)" icon="cancel" <q-btn
color="pink"></q-btn> flat
dense
size="md"
@click="showDeleteRelayDialog(props.row.url)"
icon="cancel"
color="pink"
></q-btn>
</div> </div>
<div v-else> <div v-else>
<div>{{ col.value }}</div> <div>{{ col.value }}</div>
@ -87,15 +136,32 @@
<div class="row"> <div class="row">
<div class="col"> <div class="col">
<div class="text-weight-bold"> <div class="text-weight-bold">
<q-btn flat dense size="0.6rem" class="q-px-none q-mx-none" color="grey" icon="content_copy" <q-btn
@click="copyText(`wss://${host}/nostrclient/api/v1/relay`)"><q-tooltip>Copy address</q-tooltip></q-btn> flat
dense
size="0.6rem"
class="q-px-none q-mx-none"
color="grey"
icon="content_copy"
@click="copyText(`wss://${host}/nostrclient/api/v1/relay`)"
><q-tooltip>Copy address</q-tooltip></q-btn
>
Your endpoint: Your endpoint:
<q-badge outline class="q-ml-sm text-subtitle2" :label="`wss://${host}/nostrclient/api/v1/relay`" /> <q-badge
outline
class="q-ml-sm text-subtitle2"
:label="`wss://${host}/nostrclient/api/v1/relay`"
/>
</div> </div>
</div> </div>
</div> </div>
</q-card-section> </q-card-section>
<q-expansion-item group="advanced" icon="settings" label="Test this endpoint" @click="toggleTestPanel"> <q-expansion-item
group="advanced"
icon="settings"
label="Test this endpoint"
@click="toggleTestPanel"
>
<q-separator></q-separator> <q-separator></q-separator>
<q-card-section> <q-card-section>
<div class="row"> <div class="row">
@ -103,8 +169,13 @@
<span>Sender Private Key:</span> <span>Sender Private Key:</span>
</div> </div>
<div class="col-9"> <div class="col-9">
<q-input outlined v-model="testData.senderPrivateKey" dense filled <q-input
label="Private Key (optional)"></q-input> outlined
v-model="testData.senderPrivateKey"
dense
filled
label="Private Key (optional)"
></q-input>
</div> </div>
</div> </div>
<div class="row q-mt-sm q-mb-lg"> <div class="row q-mt-sm q-mb-lg">
@ -113,7 +184,8 @@
<q-badge color="yellow" text-color="black"> <q-badge color="yellow" text-color="black">
<span> <span>
No not use your real private key! Leave empty for a randomly No not use your real private key! Leave empty for a randomly
generated key.</span> generated key.</span
>
</q-badge> </q-badge>
</div> </div>
</div> </div>
@ -122,7 +194,13 @@
<span>Sender Public Key:</span> <span>Sender Public Key:</span>
</div> </div>
<div class="col-9"> <div class="col-9">
<q-input outlined v-model="testData.senderPublicKey" dense readonly filled></q-input> <q-input
outlined
v-model="testData.senderPublicKey"
dense
readonly
filled
></q-input>
</div> </div>
</div> </div>
<div class="row q-mt-md"> <div class="row q-mt-md">
@ -130,8 +208,15 @@
<span>Test Message:</span> <span>Test Message:</span>
</div> </div>
<div class="col-9"> <div class="col-9">
<q-input outlined v-model="testData.message" dense filled rows="3" type="textarea" <q-input
label="Test Message *"></q-input> outlined
v-model="testData.message"
dense
filled
rows="3"
type="textarea"
label="Test Message *"
></q-input>
</div> </div>
</div> </div>
<div class="row q-mt-md"> <div class="row q-mt-md">
@ -139,22 +224,35 @@
<span>Receiver Public Key:</span> <span>Receiver Public Key:</span>
</div> </div>
<div class="col-9"> <div class="col-9">
<q-input outlined v-model="testData.recieverPublicKey" dense filled <q-input
label="Public Key (hex or npub) *"></q-input> outlined
v-model="testData.recieverPublicKey"
dense
filled
label="Public Key (hex or npub) *"
></q-input>
</div> </div>
</div> </div>
<div class="row q-mt-sm q-mb-lg"> <div class="row q-mt-sm q-mb-lg">
<div class="col-3"></div> <div class="col-3"></div>
<div class="col-9"> <div class="col-9">
<q-badge color="yellow" text-color="black"> <q-badge color="yellow" text-color="black">
<span>This is the recipient of the message. Field required.</span> <span
>This is the recipient of the message. Field required.</span
>
</q-badge> </q-badge>
</div> </div>
</div> </div>
<div class="row"> <div class="row">
<div class="col-12"> <div class="col-12">
<q-btn :disabled="!testData.recieverPublicKey || !testData.message" @click="sendTestMessage" unelevated <q-btn
color="primary" class="float-right">Send Message</q-btn> :disabled="!testData.recieverPublicKey || !testData.message"
@click="sendTestMessage"
unelevated
color="primary"
class="float-right"
>Send Message</q-btn
>
</div> </div>
</div> </div>
</q-card-section> </q-card-section>
@ -166,7 +264,14 @@
<span>Sent Data:</span> <span>Sent Data:</span>
</div> </div>
<div class="col-9"> <div class="col-9">
<q-input outlined v-model="testData.sentData" dense filled rows="5" type="textarea"></q-input> <q-input
outlined
v-model="testData.sentData"
dense
filled
rows="5"
type="textarea"
></q-input>
</div> </div>
</div> </div>
<div class="row q-mt-md"> <div class="row q-mt-md">
@ -174,7 +279,14 @@
<span>Received Data:</span> <span>Received Data:</span>
</div> </div>
<div class="col-9"> <div class="col-9">
<q-input outlined v-model="testData.receivedData" dense filled rows="5" type="textarea"></q-input> <q-input
outlined
v-model="testData.receivedData"
dense
filled
rows="5"
type="textarea"
></q-input>
</div> </div>
</div> </div>
</q-card-section> </q-card-section>
@ -193,8 +305,12 @@
</p> </p>
<p> <p>
<q-badge outline class="q-ml-sm text-subtitle2" color="primary" <q-badge
:label="`wss://${host}/nostrclient/api/v1/relay`" /> outline
class="q-ml-sm text-subtitle2"
color="primary"
:label="`wss://${host}/nostrclient/api/v1/relay`"
/>
</p> </p>
Only Admin users can manage this extension. Only Admin users can manage this extension.
<q-card-section></q-card-section> <q-card-section></q-card-section>
@ -204,14 +320,21 @@
<q-dialog v-model="logData.show" position="top"> <q-dialog v-model="logData.show" position="top">
<q-card class="q-pa-lg q-pt-xl"> <q-card class="q-pa-lg q-pt-xl">
<q-input filled dense v-model.trim="logData.data" type="textarea" rows="25" cols="200" label="Log Data"></q-input> <q-input
filled
dense
v-model.trim="logData.data"
type="textarea"
rows="25"
cols="200"
label="Log Data"
></q-input>
<div class="row q-mt-lg"> <div class="row q-mt-lg">
<q-btn v-close-popup flat color="grey" class="q-ml-auto">Close</q-btn> <q-btn v-close-popup flat color="grey" class="q-ml-auto">Close</q-btn>
</div> </div>
</q-card> </q-card>
</q-dialog> </q-dialog>
</div> </div>
{% endraw %} {% endblock %} {% block scripts %} {{ window_vars(user) }} {% endraw %} {% endblock %} {% block scripts %} {{ window_vars(user) }}
@ -292,8 +415,7 @@
align: 'center', align: 'center',
label: 'Ping', label: 'Ping',
field: 'ping' field: 'ping'
} },
,
{ {
name: 'delete', name: 'delete',
align: 'center', align: 'center',
@ -306,13 +428,13 @@
} }
}, },
predefinedRelays: [ predefinedRelays: [
"wss://relay.damus.io", 'wss://relay.damus.io',
"wss://nostr-pub.wellorder.net", 'wss://nostr-pub.wellorder.net',
"wss://nostr.zebedee.cloud", 'wss://nostr.zebedee.cloud',
"wss://nodestr.fmt.wiz.biz", 'wss://nodestr.fmt.wiz.biz',
"wss://nostr.oxtr.dev", 'wss://nostr.oxtr.dev',
"wss://nostr.wine" 'wss://nostr.wine'
], ]
} }
}, },
methods: { methods: {
@ -355,7 +477,7 @@
'POST', 'POST',
'/nostrclient/api/v1/relay?usr=' + this.g.user.id, '/nostrclient/api/v1/relay?usr=' + this.g.user.id,
this.g.user.wallets[0].adminkey, this.g.user.wallets[0].adminkey,
{ url: this.relayToAdd } {url: this.relayToAdd}
) )
.then(function (response) { .then(function (response) {
console.log('response:', response) console.log('response:', response)
@ -387,15 +509,15 @@
'DELETE', 'DELETE',
'/nostrclient/api/v1/relay?usr=' + this.g.user.id, '/nostrclient/api/v1/relay?usr=' + this.g.user.id,
this.g.user.wallets[0].adminkey, this.g.user.wallets[0].adminkey,
{ url: url } {url: url}
) )
.then((response) => { .then(response => {
const relayIndex = this.nostrrelayLinks.indexOf(r => r.url === url) const relayIndex = this.nostrrelayLinks.indexOf(r => r.url === url)
if (relayIndex !== -1) { if (relayIndex !== -1) {
this.nostrrelayLinks.splice(relayIndex, 1) this.nostrrelayLinks.splice(relayIndex, 1)
} }
}) })
.catch((error) => { .catch(error => {
console.error(error) console.error(error)
LNbits.utils.notifyApiError(error) LNbits.utils.notifyApiError(error)
}) })
@ -437,7 +559,7 @@
}, },
sendTestMessage: async function () { sendTestMessage: async function () {
try { try {
const { data } = await LNbits.api.request( const {data} = await LNbits.api.request(
'PUT', 'PUT',
'/nostrclient/api/v1/relay/test?usr=' + this.g.user.id, '/nostrclient/api/v1/relay/test?usr=' + this.g.user.id,
this.g.user.wallets[0].adminkey, this.g.user.wallets[0].adminkey,
@ -458,7 +580,7 @@
const subscription = JSON.stringify([ const subscription = JSON.stringify([
'REQ', 'REQ',
'test-dms', 'test-dms',
{ kinds: [4], '#p': [event.pubkey] } {kinds: [4], '#p': [event.pubkey]}
]) ])
this.testData.wsConnection.send(subscription) this.testData.wsConnection.send(subscription)
} catch (error) { } catch (error) {
@ -527,4 +649,4 @@
} }
}) })
</script> </script>
{% endblock %} {% endblock %}

View file

@ -1,6 +1,6 @@
import asyncio import asyncio
from http import HTTPStatus from http import HTTPStatus
from typing import Optional from typing import List
from fastapi import Depends, WebSocket from fastapi import Depends, WebSocket
from loguru import logger from loguru import logger
@ -9,23 +9,23 @@ from starlette.exceptions import HTTPException
from lnbits.decorators import check_admin from lnbits.decorators import check_admin
from lnbits.helpers import urlsafe_short_hash from lnbits.helpers import urlsafe_short_hash
from . import nostr, nostrclient_ext, scheduled_tasks from . import nostr_client, nostrclient_ext, scheduled_tasks
from .crud import add_relay, delete_relay, get_relays from .crud import add_relay, delete_relay, get_relays
from .helpers import normalize_public_key from .helpers import normalize_public_key
from .models import Relay, RelayList, TestMessage, TestMessageResponse from .models import Relay, TestMessage, TestMessageResponse
from .nostr.key import EncryptedDirectMessage, PrivateKey from .nostr.key import EncryptedDirectMessage, PrivateKey
from .router import NostrRouter, nostr from .router import NostrRouter
# we keep this in # we keep this in
all_routers: list[NostrRouter] = [] all_routers: list[NostrRouter] = []
@nostrclient_ext.get("/api/v1/relays") @nostrclient_ext.get("/api/v1/relays")
async def api_get_relays() -> RelayList: async def api_get_relays() -> List[Relay]:
relays = RelayList(__root__=[]) relays = []
for url, r in nostr.client.relay_manager.relays.items(): for url, r in nostr_client.relay_manager.relays.items():
relay_id = urlsafe_short_hash() relay_id = urlsafe_short_hash()
relays.__root__.append( relays.append(
Relay( Relay(
id=relay_id, id=relay_id,
url=url, url=url,
@ -47,12 +47,12 @@ async def api_get_relays() -> RelayList:
@nostrclient_ext.post( @nostrclient_ext.post(
"/api/v1/relay", status_code=HTTPStatus.OK, dependencies=[Depends(check_admin)] "/api/v1/relay", status_code=HTTPStatus.OK, dependencies=[Depends(check_admin)]
) )
async def api_add_relay(relay: Relay) -> Optional[RelayList]: async def api_add_relay(relay: Relay) -> List[Relay]:
if not relay.url: if not relay.url:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, detail=f"Relay url not provided." status_code=HTTPStatus.BAD_REQUEST, detail="Relay url not provided."
) )
if relay.url in nostr.client.relay_manager.relays: if relay.url in nostr_client.relay_manager.relays:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
detail=f"Relay: {relay.url} already exists.", detail=f"Relay: {relay.url} already exists.",
@ -60,9 +60,7 @@ async def api_add_relay(relay: Relay) -> Optional[RelayList]:
relay.id = urlsafe_short_hash() relay.id = urlsafe_short_hash()
await add_relay(relay) await add_relay(relay)
nostr.client.relays.append(relay.url) nostr_client.relay_manager.add_relay(relay.url)
nostr.client.relay_manager.add_relay(relay.url)
return await get_relays() return await get_relays()
@ -73,10 +71,10 @@ async def api_add_relay(relay: Relay) -> Optional[RelayList]:
async def api_delete_relay(relay: Relay) -> None: async def api_delete_relay(relay: Relay) -> None:
if not relay.url: if not relay.url:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, detail=f"Relay url not provided." status_code=HTTPStatus.BAD_REQUEST, detail="Relay url not provided."
) )
# we can remove relays during runtime # we can remove relays during runtime
nostr.client.relay_manager.remove_relay(relay.url) nostr_client.relay_manager.remove_relay(relay.url)
await delete_relay(relay) await delete_relay(relay)
@ -88,14 +86,18 @@ async def api_test_endpoint(data: TestMessage) -> TestMessageResponse:
to_public_key = normalize_public_key(data.reciever_public_key) to_public_key = normalize_public_key(data.reciever_public_key)
pk = bytes.fromhex(data.sender_private_key) if data.sender_private_key else None pk = bytes.fromhex(data.sender_private_key) if data.sender_private_key else None
private_key = PrivateKey(pk) private_key = PrivateKey(pk) if pk else PrivateKey()
dm = EncryptedDirectMessage( dm = EncryptedDirectMessage(
recipient_pubkey=to_public_key, cleartext_content=data.message recipient_pubkey=to_public_key, cleartext_content=data.message
) )
private_key.sign_event(dm) private_key.sign_event(dm)
return TestMessageResponse(private_key=private_key.hex(), public_key=to_public_key, event_json=dm.to_message()) return TestMessageResponse(
private_key=private_key.hex(),
public_key=to_public_key,
event_json=dm.to_message(),
)
except (ValueError, AssertionError) as ex: except (ValueError, AssertionError) as ex:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
@ -109,23 +111,18 @@ async def api_test_endpoint(data: TestMessage) -> TestMessageResponse:
) )
@nostrclient_ext.delete( @nostrclient_ext.delete(
"/api/v1", status_code=HTTPStatus.OK, dependencies=[Depends(check_admin)] "/api/v1", status_code=HTTPStatus.OK, dependencies=[Depends(check_admin)]
) )
async def api_stop(): async def api_stop():
for router in all_routers: for router in all_routers:
try: try:
for s in router.subscriptions:
nostr.client.relay_manager.close_subscription(s)
await router.stop() await router.stop()
all_routers.remove(router) all_routers.remove(router)
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
try:
nostr.client.relay_manager.close_connections() nostr_client.close()
except Exception as e:
logger.error(e)
for scheduled_task in scheduled_tasks: for scheduled_task in scheduled_tasks:
try: try:
@ -141,13 +138,14 @@ async def ws_relay(websocket: WebSocket) -> None:
"""Relay multiplexer: one client (per endpoint) <-> multiple relays""" """Relay multiplexer: one client (per endpoint) <-> multiple relays"""
await websocket.accept() await websocket.accept()
router = NostrRouter(websocket) router = NostrRouter(websocket)
await router.start() router.start()
all_routers.append(router) all_routers.append(router)
# we kill this websocket and the subscriptions if the user disconnects and thus `connected==False` # we kill this websocket and the subscriptions
while True: # if the user disconnects and thus `connected==False`
while router.connected:
await asyncio.sleep(10) await asyncio.sleep(10)
if not router.connected:
await router.stop() await router.stop()
all_routers.remove(router) all_routers.remove(router)
break