Stabilize (#24)
* refactor: clean-up
* refactor: extra logs plus try-catch
* refactor: do not use bare `except`
* refactor: clean-up redundant fields
* chore: pass code checks
* chore: code format
* refactor: code clean-up
* fix: refactoring stuff
* refactor: remove un-used file
* chore: code clean-up
* chore: code clean-up
* chore: code-format fix
* refactor: remove nostr.client wrapper
* refactor: code clean-up
* chore: code format
* refactor: remove `RelayList` class
* refactor: extract smaller methods with try-catch
* fix: better exception handling
* fix: remove redundant filters
* fix: simplify event
* chore: code format
* fix: code check
* fix: code check
* fix: simplify `REQ`
* fix: more clean-ups
* refactor: use simpler method
* refactor: re-order and rename
* fix: stop logic
* fix: subscription close before disconnect
* chore: play commit
parent ab185bd2c4
commit 16ae9d15a1

20 changed files with 522 additions and 717 deletions
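The net effect of the refactor is easiest to see end to end. Below is a minimal usage sketch against the new API surface; the names come from the diff that follows (package and module names match the file headers and imports shown below; the relay URL is a placeholder):

from nostr.relay_manager import RelayManager

manager = RelayManager()
manager.add_relay("wss://relay.example.com")  # placeholder relay URL

# Subscriptions now carry raw NIP-01 filter dicts instead of Filter objects.
manager.add_subscription("sub-1", [{"kinds": [1], "limit": 10}])

# ... drain manager.message_pool ...

manager.close_all_subscriptions()  # CLOSE frames first ("subscription close before disconnect")
manager.close_connections()       # then drop the websockets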
@@ -26,19 +26,22 @@ from enum import Enum
 class Encoding(Enum):
     """Enumeration type to list the various supported encodings."""

     BECH32 = 1
     BECH32M = 2


 CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l"
-BECH32M_CONST = 0x2bc830a3
+BECH32M_CONST = 0x2BC830A3


 def bech32_polymod(values):
     """Internal function that computes the Bech32 checksum."""
-    generator = [0x3b6a57b2, 0x26508e6d, 0x1ea119fa, 0x3d4233dd, 0x2a1462b3]
+    generator = [0x3B6A57B2, 0x26508E6D, 0x1EA119FA, 0x3D4233DD, 0x2A1462B3]
     chk = 1
     for value in values:
         top = chk >> 25
-        chk = (chk & 0x1ffffff) << 5 ^ value
+        chk = (chk & 0x1FFFFFF) << 5 ^ value
         for i in range(5):
             chk ^= generator[i] if ((top >> i) & 1) else 0
     return chk
@@ -58,6 +61,7 @@ def bech32_verify_checksum(hrp, data)
         return Encoding.BECH32M
     return None

+
 def bech32_create_checksum(hrp, data, spec):
     """Compute the checksum values given HRP and data."""
     values = bech32_hrp_expand(hrp) + data
@@ -69,26 +73,29 @@ def bech32_create_checksum(hrp, data, spec)
 def bech32_encode(hrp, data, spec):
     """Compute a Bech32 string given HRP and data values."""
     combined = data + bech32_create_checksum(hrp, data, spec)
-    return hrp + '1' + ''.join([CHARSET[d] for d in combined])
+    return hrp + "1" + "".join([CHARSET[d] for d in combined])


 def bech32_decode(bech):
     """Validate a Bech32/Bech32m string, and determine HRP and data."""
-    if ((any(ord(x) < 33 or ord(x) > 126 for x in bech)) or
-            (bech.lower() != bech and bech.upper() != bech)):
+    if (any(ord(x) < 33 or ord(x) > 126 for x in bech)) or (
+        bech.lower() != bech and bech.upper() != bech
+    ):
         return (None, None, None)
     bech = bech.lower()
-    pos = bech.rfind('1')
+    pos = bech.rfind("1")
     if pos < 1 or pos + 7 > len(bech) or len(bech) > 90:
         return (None, None, None)
-    if not all(x in CHARSET for x in bech[pos+1:]):
+    if not all(x in CHARSET for x in bech[pos + 1 :]):
         return (None, None, None)
     hrp = bech[:pos]
-    data = [CHARSET.find(x) for x in bech[pos+1:]]
+    data = [CHARSET.find(x) for x in bech[pos + 1 :]]
     spec = bech32_verify_checksum(hrp, data)
     if spec is None:
         return (None, None, None)
     return (hrp, data[:-6], spec)


 def convertbits(data, frombits, tobits, pad=True):
     """General power-of-2 base conversion."""
     acc = 0
@@ -124,7 +131,12 @@ def decode(hrp, addr)
         return (None, None)
     if data[0] == 0 and len(decoded) != 20 and len(decoded) != 32:
         return (None, None)
-    if data[0] == 0 and spec != Encoding.BECH32 or data[0] != 0 and spec != Encoding.BECH32M:
+    if (
+        data[0] == 0
+        and spec != Encoding.BECH32
+        or data[0] != 0
+        and spec != Encoding.BECH32M
+    ):
         return (None, None)
     return (data[0], decoded)
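The bech32 changes above are formatting only (uppercase hex literals, double quotes, black-style wrapping); behavior is identical. For orientation, this module is what renders keys as `npub`/`nsec` strings elsewhere in the codebase. A minimal sketch using only functions visible in this diff (the 32-byte key is a stand-in value):

from nostr import bech32  # module name per `from . import bech32` in key.py below

raw = bytes.fromhex("3b" * 32)  # stand-in 32-byte public key
npub = bech32.bech32_encode("npub", bech32.convertbits(raw, 8, 5), bech32.Encoding.BECH32)
assert bech32.bech32_decode(npub)[0] == "npub"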
@@ -1,25 +1,36 @@
 import asyncio
 from typing import List

 from loguru import logger

 from ..relay_manager import RelayManager


 class NostrClient:
-    relays = [ ]
     relay_manager = RelayManager()

-    def __init__(self, relays: List[str] = [], connect=True):
-        if len(relays):
-            self.relays = relays
-        if connect:
-            self.connect()
+    def __init__(self):
+        self.running = True

-    async def connect(self):
-        for relay in self.relays:
-            self.relay_manager.add_relay(relay)
+    def connect(self, relays):
+        for relay in relays:
+            try:
+                self.relay_manager.add_relay(relay)
+            except Exception as e:
+                logger.debug(e)
+        self.running = True
+
+    def reconnect(self, relays):
+        self.relay_manager.remove_relays()
+        self.connect(relays)

     def close(self):
-        self.relay_manager.close_connections()
+        try:
+            self.relay_manager.close_all_subscriptions()
+            self.relay_manager.close_connections()
+
+            self.running = False
+        except Exception as e:
+            logger.error(e)

     async def subscribe(
         self,
@@ -27,18 +38,36 @@ class NostrClient:
         callback_notices_func=None,
         callback_eosenotices_func=None,
     ):
-        while True:
+        while self.running:
+            self._check_events(callback_events_func)
+            self._check_notices(callback_notices_func)
+            self._check_eos_notices(callback_eosenotices_func)
+
+            await asyncio.sleep(0.2)
+
+    def _check_events(self, callback_events_func=None):
+        try:
             while self.relay_manager.message_pool.has_events():
                 event_msg = self.relay_manager.message_pool.get_event()
                 if callback_events_func:
                     callback_events_func(event_msg)
+        except Exception as e:
+            logger.debug(e)
+
+    def _check_notices(self, callback_notices_func=None):
+        try:
             while self.relay_manager.message_pool.has_notices():
                 event_msg = self.relay_manager.message_pool.get_notice()
                 if callback_notices_func:
                     callback_notices_func(event_msg)
+        except Exception as e:
+            logger.debug(e)
+
+    def _check_eos_notices(self, callback_eosenotices_func=None):
+        try:
             while self.relay_manager.message_pool.has_eose_notices():
                 event_msg = self.relay_manager.message_pool.get_eose_notice()
                 if callback_eosenotices_func:
                     callback_eosenotices_func(event_msg)
-
-            await asyncio.sleep(0.5)
+        except Exception as e:
+            logger.debug(e)
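The loop is now stoppable via `self.running`, and each pool check is isolated in its own try/except so one failing callback cannot kill the others. A minimal lifecycle sketch (hedged: `subscribe`'s leading parameters are truncated in this view, so only the callbacks visible above are passed; the module path and relay URL are assumptions):

import asyncio

from nostr.client.client import NostrClient  # path assumed from the relative imports

async def main():
    client = NostrClient()
    client.connect(["wss://relay.example.com"])  # placeholder relay

    task = asyncio.create_task(
        client.subscribe(callback_events_func=print)  # assumed keyword, per the body above
    )
    await asyncio.sleep(10)
    client.close()  # sets running = False, letting subscribe() return
    await task

asyncio.run(main())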
@@ -1,32 +0,0 @@
-import time
-from dataclasses import dataclass
-
-
-@dataclass
-class Delegation:
-    delegator_pubkey: str
-    delegatee_pubkey: str
-    event_kind: int
-    duration_secs: int = 30*24*60 # default to 30 days
-    signature: str = None # set in PrivateKey.sign_delegation
-
-    @property
-    def expires(self) -> int:
-        return int(time.time()) + self.duration_secs
-
-    @property
-    def conditions(self) -> str:
-        return f"kind={self.event_kind}&created_at<{self.expires}"
-
-    @property
-    def delegation_token(self) -> str:
-        return f"nostr:delegation:{self.delegatee_pubkey}:{self.conditions}"
-
-    def get_tag(self) -> list[str]:
-        """ Called by Event """
-        return [
-            "delegation",
-            self.delegator_pubkey,
-            self.conditions,
-            self.signature,
-        ]
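An aside on the file removed above: its default `duration_secs = 30*24*60` works out to 43,200 seconds (12 hours), not the 30 days the comment claims; 30 days would be `30 * 24 * 60 * 60`. The NIP-26 delegation support, bug included, leaves with this commit.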
@@ -122,6 +122,7 @@ class EncryptedDirectMessage(Event):
     def id(self) -> str:
         if self.content is None:
             raise Exception(
-                "EncryptedDirectMessage `id` is undefined until its message is encrypted and stored in the `content` field"
+                "EncryptedDirectMessage `id` is undefined until its"
+                + " message is encrypted and stored in the `content` field"
             )
         return super().id
nostr/filter.py (134 changed lines)
@@ -1,134 +0,0 @@
-from collections import UserList
-from typing import List
-
-from .event import Event, EventKind
-
-
-class Filter:
-    """
-    NIP-01 filtering.
-
-    Explicitly supports "#e" and "#p" tag filters via `event_refs` and `pubkey_refs`.
-
-    Arbitrary NIP-12 single-letter tag filters are also supported via `add_arbitrary_tag`.
-    If a particular single-letter tag gains prominence, explicit support should be
-    added. For example:
-        # arbitrary tag
-        filter.add_arbitrary_tag('t', [hashtags])
-
-        # promoted to explicit support
-        Filter(hashtag_refs=[hashtags])
-    """
-
-    def __init__(
-        self,
-        event_ids: List[str] = None,
-        kinds: List[EventKind] = None,
-        authors: List[str] = None,
-        since: int = None,
-        until: int = None,
-        event_refs: List[
-            str
-        ] = None, # the "#e" attr; list of event ids referenced in an "e" tag
-        pubkey_refs: List[
-            str
-        ] = None, # The "#p" attr; list of pubkeys referenced in a "p" tag
-        limit: int = None,
-    ) -> None:
-        self.event_ids = event_ids
-        self.kinds = kinds
-        self.authors = authors
-        self.since = since
-        self.until = until
-        self.event_refs = event_refs
-        self.pubkey_refs = pubkey_refs
-        self.limit = limit
-
-        self.tags = {}
-        if self.event_refs:
-            self.add_arbitrary_tag("e", self.event_refs)
-        if self.pubkey_refs:
-            self.add_arbitrary_tag("p", self.pubkey_refs)
-
-    def add_arbitrary_tag(self, tag: str, values: list):
-        """
-        Filter on any arbitrary tag with explicit handling for NIP-01 and NIP-12
-        single-letter tags.
-        """
-        # NIP-01 'e' and 'p' tags and any NIP-12 single-letter tags must be prefixed with "#"
-        tag_key = tag if len(tag) > 1 else f"#{tag}"
-        self.tags[tag_key] = values
-
-    def matches(self, event: Event) -> bool:
-        if self.event_ids is not None and event.id not in self.event_ids:
-            return False
-        if self.kinds is not None and event.kind not in self.kinds:
-            return False
-        if self.authors is not None and event.public_key not in self.authors:
-            return False
-        if self.since is not None and event.created_at < self.since:
-            return False
-        if self.until is not None and event.created_at > self.until:
-            return False
-        if (self.event_refs is not None or self.pubkey_refs is not None) and len(
-            event.tags
-        ) == 0:
-            return False
-
-        if self.tags:
-            e_tag_identifiers = set([e_tag[0] for e_tag in event.tags])
-            for f_tag, f_tag_values in self.tags.items():
-                # Omit any NIP-01 or NIP-12 "#" chars on single-letter tags
-                f_tag = f_tag.replace("#", "")
-
-                if f_tag not in e_tag_identifiers:
-                    # Event is missing a tag type that we're looking for
-                    return False
-
-                # Multiple values within f_tag_values are treated as OR search; an Event
-                # needs to match only one.
-                # Note: an Event could have multiple entries of the same tag type
-                # (e.g. a reply to multiple people) so we have to check all of them.
-                match_found = False
-                for e_tag in event.tags:
-                    if e_tag[0] == f_tag and e_tag[1] in f_tag_values:
-                        match_found = True
-                        break
-                if not match_found:
-                    return False
-
-        return True
-
-    def to_json_object(self) -> dict:
-        res = {}
-        if self.event_ids is not None:
-            res["ids"] = self.event_ids
-        if self.kinds is not None:
-            res["kinds"] = self.kinds
-        if self.authors is not None:
-            res["authors"] = self.authors
-        if self.since is not None:
-            res["since"] = self.since
-        if self.until is not None:
-            res["until"] = self.until
-        if self.limit is not None:
-            res["limit"] = self.limit
-        if self.tags:
-            res.update(self.tags)
-
-        return res
-
-
-class Filters(UserList):
-    def __init__(self, initlist: "list[Filter]" = []) -> None:
-        super().__init__(initlist)
-        self.data: "list[Filter]"
-
-    def match(self, event: Event):
-        for filter in self.data:
-            if filter.matches(event):
-                return True
-        return False
-
-    def to_json_array(self) -> list:
-        return [filter.to_json_object() for filter in self.data]
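With `Filter` and `Filters` gone, client-side matching and serialization disappear too: subscriptions now hold raw NIP-01 filter dicts, and `Relay.publish_subscriptions` (later in this diff) dumps them into the `REQ` frame as-is. A sketch of what replaces `Filter(kinds=[1], authors=[...], limit=10)`:

import json

# Raw NIP-01 filter dict; "<hex-pubkey>" is a placeholder.
filters = [{"kinds": [1], "authors": ["<hex-pubkey>"], "limit": 10}]

# Mirrors relay.py below: json.dumps(["REQ", s.id] + s.filters)
print(json.dumps(["REQ", "sub-1"] + filters))
# ["REQ", "sub-1", {"kinds": [1], "authors": ["<hex-pubkey>"], "limit": 10}]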
nostr/key.py (14 changed lines)
@@ -1,6 +1,5 @@
 import base64
 import secrets
-from hashlib import sha256

 import secp256k1
 from cffi import FFI
@@ -8,7 +7,6 @@ from cryptography.hazmat.primitives import padding
 from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes

 from . import bech32
-from .delegation import Delegation
 from .event import EncryptedDirectMessage, Event, EventKind


@@ -37,7 +35,7 @@ class PublicKey:

 class PrivateKey:
     def __init__(self, raw_secret: bytes = None) -> None:
-        if not raw_secret is None:
+        if raw_secret is not None:
             self.raw_secret = raw_secret
         else:
             self.raw_secret = secrets.token_bytes(32)
@@ -79,7 +77,10 @@ class PrivateKey:
         encryptor = cipher.encryptor()
         encrypted_message = encryptor.update(padded_data) + encryptor.finalize()

-        return f"{base64.b64encode(encrypted_message).decode()}?iv={base64.b64encode(iv).decode()}"
+        return (
+            f"{base64.b64encode(encrypted_message).decode()}"
+            + f"?iv={base64.b64encode(iv).decode()}"
+        )

     def encrypt_dm(self, dm: EncryptedDirectMessage) -> None:
         dm.content = self.encrypt_message(
@@ -116,11 +117,6 @@ class PrivateKey:
         event.public_key = self.public_key.hex()
         event.signature = self.sign_message_hash(bytes.fromhex(event.id))

-    def sign_delegation(self, delegation: Delegation) -> None:
-        delegation.signature = self.sign_message_hash(
-            sha256(delegation.delegation_token.encode()).digest()
-        )
-
     def __eq__(self, other):
         return self.raw_secret == other.raw_secret
@@ -2,13 +2,15 @@ import json
 from queue import Queue
 from threading import Lock

-from .event import Event
 from .message_type import RelayMessageType


 class EventMessage:
-    def __init__(self, event: Event, subscription_id: str, url: str) -> None:
+    def __init__(
+        self, event: str, event_id: str, subscription_id: str, url: str
+    ) -> None:
         self.event = event
+        self.event_id = event_id
         self.subscription_id = subscription_id
         self.url = url
@@ -59,18 +61,16 @@ class MessagePool:
         message_type = message_json[0]
         if message_type == RelayMessageType.EVENT:
             subscription_id = message_json[1]
-            e = message_json[2]
-            event = Event(
-                e["content"],
-                e["pubkey"],
-                e["created_at"],
-                e["kind"],
-                e["tags"],
-                e["sig"],
-            )
+            event = message_json[2]
+            if "id" not in event:
+                return
+            event_id = event["id"]
+
             with self.lock:
-                if not f"{subscription_id}_{event.id}" in self._unique_events:
-                    self._accept_event(EventMessage(event, subscription_id, url))
+                if f"{subscription_id}_{event_id}" not in self._unique_events:
+                    self._accept_event(
+                        EventMessage(json.dumps(event), event_id, subscription_id, url)
+                    )
         elif message_type == RelayMessageType.NOTICE:
             self.notices.put(NoticeMessage(message_json[1], url))
         elif message_type == RelayMessageType.END_OF_STORED_EVENTS:
@@ -78,10 +78,12 @@ class MessagePool:

     def _accept_event(self, event_message: EventMessage):
         """
-        Event uniqueness is considered per `subscription_id`.
-        The `subscription_id` is rewritten to be unique and it is the same accross relays.
-        The same event can come from different subscriptions (from the same client or from different ones).
-        Clients that have joined later should receive older events.
+        Event uniqueness is considered per `subscription_id`. The `subscription_id` is
+        rewritten to be unique and it is the same accross relays. The same event can
+        come from different subscriptions (from the same client or from different ones).
+        Clients that have joined later should receive older events.
         """
         self.events.put(event_message)
-        self._unique_events.add(f"{event_message.subscription_id}_{event_message.event.id}")
+        self._unique_events.add(
+            f"{event_message.subscription_id}_{event_message.event_id}"
+        )
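The pool now deduplicates on the raw `id` field before any model object is built, and `EventMessage.event` carries the re-serialized JSON string rather than an `Event` instance. The dedup rule is simple enough to show standalone (illustration only, mirroring the `_handle_message` logic above):

unique_events = set()

def accept(subscription_id: str, event: dict) -> bool:
    # Same event id under the same subscription id is accepted once,
    # no matter how many relays deliver it.
    key = f"{subscription_id}_{event['id']}"
    if key in unique_events:
        return False
    unique_events.add(key)
    return True

assert accept("sub-1", {"id": "abc"}) is True   # first delivery
assert accept("sub-1", {"id": "abc"}) is False  # duplicate from another relay
assert accept("sub-2", {"id": "abc"}) is True   # different subscription: kept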
nostr/relay.py (165 changed lines)
@@ -2,43 +2,23 @@ import asyncio
 import json
 import time
 from queue import Queue
 from threading import Lock
 from typing import List

 from loguru import logger
 from websocket import WebSocketApp

-from .event import Event
-from .filter import Filters
 from .message_pool import MessagePool
 from .message_type import RelayMessageType
 from .subscription import Subscription


-class RelayPolicy:
-    def __init__(self, should_read: bool = True, should_write: bool = True) -> None:
-        self.should_read = should_read
-        self.should_write = should_write
-
-    def to_json_object(self) -> dict[str, bool]:
-        return {"read": self.should_read, "write": self.should_write}
-
-
 class Relay:
-    def __init__(
-        self,
-        url: str,
-        policy: RelayPolicy,
-        message_pool: MessagePool,
-        subscriptions: dict[str, Subscription] = {},
-    ) -> None:
+    def __init__(self, url: str, message_pool: MessagePool) -> None:
         self.url = url
-        self.policy = policy
         self.message_pool = message_pool
-        self.subscriptions = subscriptions
         self.connected: bool = False
         self.reconnect: bool = True
         self.shutdown: bool = False

         self.error_counter: int = 0
-        self.error_threshold: int = 100
         self.error_list: List[str] = []
@@ -47,12 +27,10 @@ class Relay:
         self.num_received_events: int = 0
         self.num_sent_events: int = 0
         self.num_subscriptions: int = 0
-        self.ssl_options: dict = {}
-        self.proxy: dict = {}
         self.lock = Lock()

         self.queue = Queue()

-    def connect(self, ssl_options: dict = None, proxy: dict = None):
+    def connect(self):
         self.ws = WebSocketApp(
             self.url,
             on_open=self._on_open,
@@ -62,19 +40,14 @@ class Relay:
             on_ping=self._on_ping,
             on_pong=self._on_pong,
         )
-        self.ssl_options = ssl_options
-        self.proxy = proxy
         if not self.connected:
-            self.ws.run_forever(
-                sslopt=ssl_options,
-                http_proxy_host=None if proxy is None else proxy.get("host"),
-                http_proxy_port=None if proxy is None else proxy.get("port"),
-                proxy_type=None if proxy is None else proxy.get("type"),
-                ping_interval=5,
-            )
+            self.ws.run_forever(ping_interval=10)

     def close(self):
-        self.ws.close()
+        try:
+            self.ws.close()
+        except Exception as e:
+            logger.warning(f"[Relay: {self.url}] Failed to close websocket: {e}")
         self.connected = False
         self.shutdown = True
@@ -90,10 +63,9 @@ class Relay:
     def publish(self, message: str):
         self.queue.put(message)

-    def publish_subscriptions(self):
-        for _, subscription in self.subscriptions.items():
-            s = subscription.to_json_object()
-            json_str = json.dumps(["REQ", s["id"], s["filters"][0]])
+    def publish_subscriptions(self, subscriptions: List[Subscription] = []):
+        for s in subscriptions:
+            json_str = json.dumps(["REQ", s.id] + s.filters)
             self.publish(json_str)

     async def queue_worker(self):
@@ -103,55 +75,44 @@ class Relay:
                 message = self.queue.get(timeout=1)
                 self.num_sent_events += 1
                 self.ws.send(message)
-            except:
+            except Exception as _:
                 pass
         else:
             await asyncio.sleep(1)

-            if self.shutdown:
-                logger.warning(f"Closing queue worker for '{self.url}'.")
-                break
-
-    def add_subscription(self, id, filters: Filters):
-        with self.lock:
-            self.subscriptions[id] = Subscription(id, filters)
+        if self.shutdown:
+            logger.warning(f"[Relay: {self.url}] Closing queue worker.")
+            return

     def close_subscription(self, id: str) -> None:
-        with self.lock:
-            self.subscriptions.pop(id)
-
-    def to_json_object(self) -> dict:
-        return {
-            "url": self.url,
-            "policy": self.policy.to_json_object(),
-            "subscriptions": [
-                subscription.to_json_object()
-                for subscription in self.subscriptions.values()
-            ],
-        }
+        try:
+            self.publish(json.dumps(["CLOSE", id]))
+        except Exception as e:
+            logger.debug(f"[Relay: {self.url}] Failed to close subscription: {e}")

     def add_notice(self, notice: str):
-        self.notice_list = ([notice] + self.notice_list)[:20]
+        self.notice_list = [notice] + self.notice_list

     def _on_open(self, _):
-        logger.info(f"Connected to relay: '{self.url}'.")
+        logger.info(f"[Relay: {self.url}] Connected.")
         self.connected = True
+        self.shutdown = False

     def _on_close(self, _, status_code, message):
-        logger.warning(f"Connection to relay {self.url} closed. Status: '{status_code}'. Message: '{message}'.")
+        logger.warning(
+            f"[Relay: {self.url}] Connection closed."
+            + f" Status: '{status_code}'. Message: '{message}'."
+        )
         self.close()

     def _on_message(self, _, message: str):
-        if self._is_valid_message(message):
-            self.num_received_events += 1
-            self.message_pool.add_message(message, self.url)
+        self.num_received_events += 1
+        self.message_pool.add_message(message, self.url)

     def _on_error(self, _, error):
-        logger.warning(f"Relay error: '{str(error)}'")
+        logger.warning(f"[Relay: {self.url}] Error: '{str(error)}'")
         self._append_error_message(str(error))
-        self.connected = False
-        self.error_counter += 1
+        self.close()

     def _on_ping(self, *_):
         return
@@ -159,65 +120,7 @@ class Relay:
     def _on_pong(self, *_):
         return

-    def _is_valid_message(self, message: str) -> bool:
-        message = message.strip("\n")
-        if not message or message[0] != "[" or message[-1] != "]":
-            return False
-
-        message_json = json.loads(message)
-        message_type = message_json[0]
-
-        if not RelayMessageType.is_valid(message_type):
-            return False
-
-        if message_type == RelayMessageType.EVENT:
-            return self._is_valid_event_message(message_json)
-
-        if message_type == RelayMessageType.COMMAND_RESULT:
-            return self._is_valid_command_result_message(message, message_json)
-
-        return True
-
-    def _is_valid_event_message(self, message_json):
-        if not len(message_json) == 3:
-            return False
-
-        subscription_id = message_json[1]
-        with self.lock:
-            if subscription_id not in self.subscriptions:
-                return False
-
-        e = message_json[2]
-        event = Event(
-            e["content"],
-            e["pubkey"],
-            e["created_at"],
-            e["kind"],
-            e["tags"],
-            e["sig"],
-        )
-        if not event.verify():
-            return False
-
-        with self.lock:
-            subscription = self.subscriptions[subscription_id]
-
-        if subscription.filters and not subscription.filters.match(event):
-            return False
-
-        return True
-
-    def _is_valid_command_result_message(self, message, message_json):
-        if not len(message_json) < 3:
-            return False
-
-        if message_json[2] != True:
-            logger.warning(f"Relay '{self.url}' negative command result: '{message}'")
-            self._append_error_message(message)
-            return False
-
-        return True
-
     def _append_error_message(self, message):
-        self.error_list = ([message] + self.error_list)[:20]
-        self.last_error_date = int(time.time())
+        self.error_counter += 1
+        self.error_list = [message] + self.error_list
+        self.last_error_date = int(time.time())
@@ -1,21 +1,15 @@
 import asyncio
-import ssl
 import threading
 import time
 from typing import List

 from loguru import logger

-from .filter import Filters
 from .message_pool import MessagePool, NoticeMessage
-from .relay import Relay, RelayPolicy
+from .relay import Relay
 from .subscription import Subscription


-class RelayException(Exception):
-    pass
-
-
 class RelayManager:
     def __init__(self) -> None:
         self.relays: dict[str, Relay] = {}
@@ -25,72 +19,97 @@ class RelayManager:
         self._cached_subscriptions: dict[str, Subscription] = {}
         self._subscriptions_lock = threading.Lock()

-    def add_relay(self, url: str, read: bool = True, write: bool = True) -> Relay:
+    def add_relay(self, url: str) -> Relay:
         if url in list(self.relays.keys()):
-            return
-
-        with self._subscriptions_lock:
-            subscriptions = self._cached_subscriptions.copy()
-
-        policy = RelayPolicy(read, write)
-        relay = Relay(url, policy, self.message_pool, subscriptions)
+            logger.debug(f"Relay '{url}' already present.")
+            return self.relays[url]
+
+        relay = Relay(url, self.message_pool)
         self.relays[url] = relay

-        self._open_connection(
-            relay,
-            {"cert_reqs": ssl.CERT_NONE}
-        ) # NOTE: This disables ssl certificate verification
+        self._open_connection(relay)

-        relay.publish_subscriptions()
+        relay.publish_subscriptions(list(self._cached_subscriptions.values()))
+        return relay

     def remove_relay(self, url: str):
-        self.relays[url].close()
-        self.relays.pop(url)
-        self.threads[url].join(timeout=5)
-        self.threads.pop(url)
-        self.queue_threads[url].join(timeout=5)
-        self.queue_threads.pop(url)
-
-    def add_subscription(self, id: str, filters: Filters):
+        try:
+            self.relays[url].close()
+        except Exception as e:
+            logger.debug(e)
+
+        if url in self.relays:
+            self.relays.pop(url)
+
+        try:
+            self.threads[url].join(timeout=5)
+        except Exception as e:
+            logger.debug(e)
+
+        if url in self.threads:
+            self.threads.pop(url)
+
+        try:
+            self.queue_threads[url].join(timeout=5)
+        except Exception as e:
+            logger.debug(e)
+
+        if url in self.queue_threads:
+            self.queue_threads.pop(url)
+
+    def remove_relays(self):
+        relay_urls = list(self.relays.keys())
+        for url in relay_urls:
+            self.remove_relay(url)
+
+    def add_subscription(self, id: str, filters: List[str]):
+        s = Subscription(id, filters)
         with self._subscriptions_lock:
-            self._cached_subscriptions[id] = Subscription(id, filters)
+            self._cached_subscriptions[id] = s

         for relay in self.relays.values():
-            relay.add_subscription(id, filters)
+            relay.publish_subscriptions([s])

     def close_subscription(self, id: str):
-        with self._subscriptions_lock:
-            self._cached_subscriptions.pop(id)
-
-        for relay in self.relays.values():
-            relay.close_subscription(id)
+        try:
+            with self._subscriptions_lock:
+                if id in self._cached_subscriptions:
+                    self._cached_subscriptions.pop(id)
+
+            for relay in self.relays.values():
+                relay.close_subscription(id)
+        except Exception as e:
+            logger.debug(e)
+
+    def close_subscriptions(self, subscriptions: List[str]):
+        for id in subscriptions:
+            self.close_subscription(id)
+
+    def close_all_subscriptions(self):
+        all_subscriptions = list(self._cached_subscriptions.keys())
+        self.close_subscriptions(all_subscriptions)
+
+    def check_and_restart_relays(self):
+        stopped_relays = [r for r in self.relays.values() if r.shutdown]
+        for relay in stopped_relays:
+            self._restart_relay(relay)

     def close_connections(self):
         for relay in self.relays.values():
             relay.close()

     def publish_message(self, message: str):
         for relay in self.relays.values():
-            if relay.policy.should_write:
-                relay.publish(message)
+            relay.publish(message)

     def handle_notice(self, notice: NoticeMessage):
         relay = next((r for r in self.relays.values() if r.url == notice.url))
         if relay:
             relay.add_notice(notice.content)

-    def _open_connection(self, relay: Relay, ssl_options: dict = None, proxy: dict = None):
+    def _open_connection(self, relay: Relay):
         self.threads[relay.url] = threading.Thread(
             target=relay.connect,
-            args=(ssl_options, proxy),
             name=f"{relay.url}-thread",
             daemon=True,
         )
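Note that nothing in this diff invokes `check_and_restart_relays` itself; a caller is expected to poll it. A hypothetical supervisor task (the function, its cadence, and the loop are assumptions, not part of this commit):

async def watchdog(manager: RelayManager):
    # Revive relays whose websocket thread marked them `shutdown`.
    while True:
        manager.check_and_restart_relays()
        await asyncio.sleep(15)  # assumed cadence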
@@ -98,7 +117,7 @@ class RelayManager:

         def wrap_async_queue_worker():
             asyncio.run(relay.queue_worker())

         self.queue_threads[relay.url] = threading.Thread(
             target=wrap_async_queue_worker,
             name=f"{relay.url}-queue",
@@ -108,14 +127,16 @@ class RelayManager:
     def _restart_relay(self, relay: Relay):
         time_since_last_error = time.time() - relay.last_error_date

-        min_wait_time = min(60 * relay.error_counter, 60 * 60 * 24) # try at least once a day
-
+        min_wait_time = min(
+            60 * relay.error_counter, 60 * 60
+        )  # try at least once an hour
         if time_since_last_error < min_wait_time:
             return

         logger.info(f"Restarting connection to relay '{relay.url}'")

         self.remove_relay(relay.url)
         new_relay = self.add_relay(relay.url)
         new_relay.error_counter = relay.error_counter
         new_relay.error_list = relay.error_list
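The reconnect backoff now grows one minute per accumulated error and caps at an hour instead of a day. Concretely:

# Worked examples of the new backoff: min(60 * error_counter, 60 * 60)
for error_counter in (1, 5, 60, 500):
    print(error_counter, min(60 * error_counter, 60 * 60))
# 1 -> 60s, 5 -> 300s, 60 -> 3600s, 500 -> still 3600s (capped)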
@@ -1,13 +1,7 @@
-from .filter import Filters
+from typing import List


 class Subscription:
-    def __init__(self, id: str, filters: Filters=None) -> None:
+    def __init__(self, id: str, filters: List[str] = None) -> None:
         self.id = id
         self.filters = filters
-
-    def to_json_object(self):
-        return {
-            "id": self.id,
-            "filters": self.filters.to_json_array()
-        }