Stabilize (#24)

* refactor: clean-up

* refactor: extra logs plus try-catch

* refactor: do not use bare `except`

* refactor: clean-up redundant fields

* chore: pass code checks

* chore: code format

* refactor: code clean-up

* fix: refactoring stuff

* refactor: remove un-used file

* chore: code clean-up

* chore: code clean-up

* chore: code-format fix

* refactor: remove nostr.client wrapper

* refactor: code clean-up

* chore: code format

* refactor: remove `RelayList` class

* refactor: extract smaller methods with try-catch

* fix: better exception handling

* fix: remove redundant filters

* fix: simplify event

* chore: code format

* fix: code check

* fix: code check

* fix: simplify `REQ`

* fix: more clean-ups

* refactor: use simpler method

* refactor: re-order and rename

* fix: stop logic

* fix: subscription close before disconnect

* chore: play commit
This commit is contained in:
Vlad Stan 2023-11-01 17:46:42 +02:00 committed by GitHub
parent ab185bd2c4
commit 16ae9d15a1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 522 additions and 717 deletions

View file

View file

@ -26,19 +26,22 @@ from enum import Enum
class Encoding(Enum):
"""Enumeration type to list the various supported encodings."""
BECH32 = 1
BECH32M = 2
CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l"
BECH32M_CONST = 0x2bc830a3
BECH32M_CONST = 0x2BC830A3
def bech32_polymod(values):
"""Internal function that computes the Bech32 checksum."""
generator = [0x3b6a57b2, 0x26508e6d, 0x1ea119fa, 0x3d4233dd, 0x2a1462b3]
generator = [0x3B6A57B2, 0x26508E6D, 0x1EA119FA, 0x3D4233DD, 0x2A1462B3]
chk = 1
for value in values:
top = chk >> 25
chk = (chk & 0x1ffffff) << 5 ^ value
chk = (chk & 0x1FFFFFF) << 5 ^ value
for i in range(5):
chk ^= generator[i] if ((top >> i) & 1) else 0
return chk
@ -58,6 +61,7 @@ def bech32_verify_checksum(hrp, data):
return Encoding.BECH32M
return None
def bech32_create_checksum(hrp, data, spec):
"""Compute the checksum values given HRP and data."""
values = bech32_hrp_expand(hrp) + data
@ -69,26 +73,29 @@ def bech32_create_checksum(hrp, data, spec):
def bech32_encode(hrp, data, spec):
"""Compute a Bech32 string given HRP and data values."""
combined = data + bech32_create_checksum(hrp, data, spec)
return hrp + '1' + ''.join([CHARSET[d] for d in combined])
return hrp + "1" + "".join([CHARSET[d] for d in combined])
def bech32_decode(bech):
"""Validate a Bech32/Bech32m string, and determine HRP and data."""
if ((any(ord(x) < 33 or ord(x) > 126 for x in bech)) or
(bech.lower() != bech and bech.upper() != bech)):
if (any(ord(x) < 33 or ord(x) > 126 for x in bech)) or (
bech.lower() != bech and bech.upper() != bech
):
return (None, None, None)
bech = bech.lower()
pos = bech.rfind('1')
pos = bech.rfind("1")
if pos < 1 or pos + 7 > len(bech) or len(bech) > 90:
return (None, None, None)
if not all(x in CHARSET for x in bech[pos+1:]):
if not all(x in CHARSET for x in bech[pos + 1 :]):
return (None, None, None)
hrp = bech[:pos]
data = [CHARSET.find(x) for x in bech[pos+1:]]
data = [CHARSET.find(x) for x in bech[pos + 1 :]]
spec = bech32_verify_checksum(hrp, data)
if spec is None:
return (None, None, None)
return (hrp, data[:-6], spec)
def convertbits(data, frombits, tobits, pad=True):
"""General power-of-2 base conversion."""
acc = 0
@ -124,7 +131,12 @@ def decode(hrp, addr):
return (None, None)
if data[0] == 0 and len(decoded) != 20 and len(decoded) != 32:
return (None, None)
if data[0] == 0 and spec != Encoding.BECH32 or data[0] != 0 and spec != Encoding.BECH32M:
if (
data[0] == 0
and spec != Encoding.BECH32
or data[0] != 0
and spec != Encoding.BECH32M
):
return (None, None)
return (data[0], decoded)

View file

@ -1,25 +1,36 @@
import asyncio
from typing import List
from loguru import logger
from ..relay_manager import RelayManager
class NostrClient:
relays = [ ]
relay_manager = RelayManager()
def __init__(self, relays: List[str] = [], connect=True):
if len(relays):
self.relays = relays
if connect:
self.connect()
def __init__(self):
self.running = True
async def connect(self):
for relay in self.relays:
self.relay_manager.add_relay(relay)
def connect(self, relays):
for relay in relays:
try:
self.relay_manager.add_relay(relay)
except Exception as e:
logger.debug(e)
self.running = True
def reconnect(self, relays):
self.relay_manager.remove_relays()
self.connect(relays)
def close(self):
self.relay_manager.close_connections()
try:
self.relay_manager.close_all_subscriptions()
self.relay_manager.close_connections()
self.running = False
except Exception as e:
logger.error(e)
async def subscribe(
self,
@ -27,18 +38,36 @@ class NostrClient:
callback_notices_func=None,
callback_eosenotices_func=None,
):
while True:
while self.running:
self._check_events(callback_events_func)
self._check_notices(callback_notices_func)
self._check_eos_notices(callback_eosenotices_func)
await asyncio.sleep(0.2)
def _check_events(self, callback_events_func=None):
try:
while self.relay_manager.message_pool.has_events():
event_msg = self.relay_manager.message_pool.get_event()
if callback_events_func:
callback_events_func(event_msg)
except Exception as e:
logger.debug(e)
def _check_notices(self, callback_notices_func=None):
try:
while self.relay_manager.message_pool.has_notices():
event_msg = self.relay_manager.message_pool.get_notice()
if callback_notices_func:
callback_notices_func(event_msg)
except Exception as e:
logger.debug(e)
def _check_eos_notices(self, callback_eosenotices_func=None):
try:
while self.relay_manager.message_pool.has_eose_notices():
event_msg = self.relay_manager.message_pool.get_eose_notice()
if callback_eosenotices_func:
callback_eosenotices_func(event_msg)
await asyncio.sleep(0.5)
except Exception as e:
logger.debug(e)

View file

@ -1,32 +0,0 @@
import time
from dataclasses import dataclass
@dataclass
class Delegation:
delegator_pubkey: str
delegatee_pubkey: str
event_kind: int
duration_secs: int = 30*24*60 # default to 30 days
signature: str = None # set in PrivateKey.sign_delegation
@property
def expires(self) -> int:
return int(time.time()) + self.duration_secs
@property
def conditions(self) -> str:
return f"kind={self.event_kind}&created_at<{self.expires}"
@property
def delegation_token(self) -> str:
return f"nostr:delegation:{self.delegatee_pubkey}:{self.conditions}"
def get_tag(self) -> list[str]:
""" Called by Event """
return [
"delegation",
self.delegator_pubkey,
self.conditions,
self.signature,
]

View file

@ -122,6 +122,7 @@ class EncryptedDirectMessage(Event):
def id(self) -> str:
if self.content is None:
raise Exception(
"EncryptedDirectMessage `id` is undefined until its message is encrypted and stored in the `content` field"
"EncryptedDirectMessage `id` is undefined until its"
+ " message is encrypted and stored in the `content` field"
)
return super().id

View file

@ -1,134 +0,0 @@
from collections import UserList
from typing import List
from .event import Event, EventKind
class Filter:
"""
NIP-01 filtering.
Explicitly supports "#e" and "#p" tag filters via `event_refs` and `pubkey_refs`.
Arbitrary NIP-12 single-letter tag filters are also supported via `add_arbitrary_tag`.
If a particular single-letter tag gains prominence, explicit support should be
added. For example:
# arbitrary tag
filter.add_arbitrary_tag('t', [hashtags])
# promoted to explicit support
Filter(hashtag_refs=[hashtags])
"""
def __init__(
self,
event_ids: List[str] = None,
kinds: List[EventKind] = None,
authors: List[str] = None,
since: int = None,
until: int = None,
event_refs: List[
str
] = None, # the "#e" attr; list of event ids referenced in an "e" tag
pubkey_refs: List[
str
] = None, # The "#p" attr; list of pubkeys referenced in a "p" tag
limit: int = None,
) -> None:
self.event_ids = event_ids
self.kinds = kinds
self.authors = authors
self.since = since
self.until = until
self.event_refs = event_refs
self.pubkey_refs = pubkey_refs
self.limit = limit
self.tags = {}
if self.event_refs:
self.add_arbitrary_tag("e", self.event_refs)
if self.pubkey_refs:
self.add_arbitrary_tag("p", self.pubkey_refs)
def add_arbitrary_tag(self, tag: str, values: list):
"""
Filter on any arbitrary tag with explicit handling for NIP-01 and NIP-12
single-letter tags.
"""
# NIP-01 'e' and 'p' tags and any NIP-12 single-letter tags must be prefixed with "#"
tag_key = tag if len(tag) > 1 else f"#{tag}"
self.tags[tag_key] = values
def matches(self, event: Event) -> bool:
if self.event_ids is not None and event.id not in self.event_ids:
return False
if self.kinds is not None and event.kind not in self.kinds:
return False
if self.authors is not None and event.public_key not in self.authors:
return False
if self.since is not None and event.created_at < self.since:
return False
if self.until is not None and event.created_at > self.until:
return False
if (self.event_refs is not None or self.pubkey_refs is not None) and len(
event.tags
) == 0:
return False
if self.tags:
e_tag_identifiers = set([e_tag[0] for e_tag in event.tags])
for f_tag, f_tag_values in self.tags.items():
# Omit any NIP-01 or NIP-12 "#" chars on single-letter tags
f_tag = f_tag.replace("#", "")
if f_tag not in e_tag_identifiers:
# Event is missing a tag type that we're looking for
return False
# Multiple values within f_tag_values are treated as OR search; an Event
# needs to match only one.
# Note: an Event could have multiple entries of the same tag type
# (e.g. a reply to multiple people) so we have to check all of them.
match_found = False
for e_tag in event.tags:
if e_tag[0] == f_tag and e_tag[1] in f_tag_values:
match_found = True
break
if not match_found:
return False
return True
def to_json_object(self) -> dict:
res = {}
if self.event_ids is not None:
res["ids"] = self.event_ids
if self.kinds is not None:
res["kinds"] = self.kinds
if self.authors is not None:
res["authors"] = self.authors
if self.since is not None:
res["since"] = self.since
if self.until is not None:
res["until"] = self.until
if self.limit is not None:
res["limit"] = self.limit
if self.tags:
res.update(self.tags)
return res
class Filters(UserList):
def __init__(self, initlist: "list[Filter]" = []) -> None:
super().__init__(initlist)
self.data: "list[Filter]"
def match(self, event: Event):
for filter in self.data:
if filter.matches(event):
return True
return False
def to_json_array(self) -> list:
return [filter.to_json_object() for filter in self.data]

View file

@ -1,6 +1,5 @@
import base64
import secrets
from hashlib import sha256
import secp256k1
from cffi import FFI
@ -8,7 +7,6 @@ from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from . import bech32
from .delegation import Delegation
from .event import EncryptedDirectMessage, Event, EventKind
@ -37,7 +35,7 @@ class PublicKey:
class PrivateKey:
def __init__(self, raw_secret: bytes = None) -> None:
if not raw_secret is None:
if raw_secret is not None:
self.raw_secret = raw_secret
else:
self.raw_secret = secrets.token_bytes(32)
@ -79,7 +77,10 @@ class PrivateKey:
encryptor = cipher.encryptor()
encrypted_message = encryptor.update(padded_data) + encryptor.finalize()
return f"{base64.b64encode(encrypted_message).decode()}?iv={base64.b64encode(iv).decode()}"
return (
f"{base64.b64encode(encrypted_message).decode()}"
+ f"?iv={base64.b64encode(iv).decode()}"
)
def encrypt_dm(self, dm: EncryptedDirectMessage) -> None:
dm.content = self.encrypt_message(
@ -116,11 +117,6 @@ class PrivateKey:
event.public_key = self.public_key.hex()
event.signature = self.sign_message_hash(bytes.fromhex(event.id))
def sign_delegation(self, delegation: Delegation) -> None:
delegation.signature = self.sign_message_hash(
sha256(delegation.delegation_token.encode()).digest()
)
def __eq__(self, other):
return self.raw_secret == other.raw_secret

View file

@ -2,13 +2,15 @@ import json
from queue import Queue
from threading import Lock
from .event import Event
from .message_type import RelayMessageType
class EventMessage:
def __init__(self, event: Event, subscription_id: str, url: str) -> None:
def __init__(
self, event: str, event_id: str, subscription_id: str, url: str
) -> None:
self.event = event
self.event_id = event_id
self.subscription_id = subscription_id
self.url = url
@ -59,18 +61,16 @@ class MessagePool:
message_type = message_json[0]
if message_type == RelayMessageType.EVENT:
subscription_id = message_json[1]
e = message_json[2]
event = Event(
e["content"],
e["pubkey"],
e["created_at"],
e["kind"],
e["tags"],
e["sig"],
)
event = message_json[2]
if "id" not in event:
return
event_id = event["id"]
with self.lock:
if not f"{subscription_id}_{event.id}" in self._unique_events:
self._accept_event(EventMessage(event, subscription_id, url))
if f"{subscription_id}_{event_id}" not in self._unique_events:
self._accept_event(
EventMessage(json.dumps(event), event_id, subscription_id, url)
)
elif message_type == RelayMessageType.NOTICE:
self.notices.put(NoticeMessage(message_json[1], url))
elif message_type == RelayMessageType.END_OF_STORED_EVENTS:
@ -78,10 +78,12 @@ class MessagePool:
def _accept_event(self, event_message: EventMessage):
"""
Event uniqueness is considered per `subscription_id`.
The `subscription_id` is rewritten to be unique and it is the same accross relays.
The same event can come from different subscriptions (from the same client or from different ones).
Clients that have joined later should receive older events.
Event uniqueness is considered per `subscription_id`. The `subscription_id` is
rewritten to be unique and it is the same accross relays. The same event can
come from different subscriptions (from the same client or from different ones).
Clients that have joined later should receive older events.
"""
self.events.put(event_message)
self._unique_events.add(f"{event_message.subscription_id}_{event_message.event.id}")
self._unique_events.add(
f"{event_message.subscription_id}_{event_message.event_id}"
)

View file

@ -2,43 +2,23 @@ import asyncio
import json
import time
from queue import Queue
from threading import Lock
from typing import List
from loguru import logger
from websocket import WebSocketApp
from .event import Event
from .filter import Filters
from .message_pool import MessagePool
from .message_type import RelayMessageType
from .subscription import Subscription
class RelayPolicy:
def __init__(self, should_read: bool = True, should_write: bool = True) -> None:
self.should_read = should_read
self.should_write = should_write
def to_json_object(self) -> dict[str, bool]:
return {"read": self.should_read, "write": self.should_write}
class Relay:
def __init__(
self,
url: str,
policy: RelayPolicy,
message_pool: MessagePool,
subscriptions: dict[str, Subscription] = {},
) -> None:
def __init__(self, url: str, message_pool: MessagePool) -> None:
self.url = url
self.policy = policy
self.message_pool = message_pool
self.subscriptions = subscriptions
self.connected: bool = False
self.reconnect: bool = True
self.shutdown: bool = False
self.error_counter: int = 0
self.error_threshold: int = 100
self.error_list: List[str] = []
@ -47,12 +27,10 @@ class Relay:
self.num_received_events: int = 0
self.num_sent_events: int = 0
self.num_subscriptions: int = 0
self.ssl_options: dict = {}
self.proxy: dict = {}
self.lock = Lock()
self.queue = Queue()
def connect(self, ssl_options: dict = None, proxy: dict = None):
def connect(self):
self.ws = WebSocketApp(
self.url,
on_open=self._on_open,
@ -62,19 +40,14 @@ class Relay:
on_ping=self._on_ping,
on_pong=self._on_pong,
)
self.ssl_options = ssl_options
self.proxy = proxy
if not self.connected:
self.ws.run_forever(
sslopt=ssl_options,
http_proxy_host=None if proxy is None else proxy.get("host"),
http_proxy_port=None if proxy is None else proxy.get("port"),
proxy_type=None if proxy is None else proxy.get("type"),
ping_interval=5,
)
self.ws.run_forever(ping_interval=10)
def close(self):
self.ws.close()
try:
self.ws.close()
except Exception as e:
logger.warning(f"[Relay: {self.url}] Failed to close websocket: {e}")
self.connected = False
self.shutdown = True
@ -90,10 +63,9 @@ class Relay:
def publish(self, message: str):
self.queue.put(message)
def publish_subscriptions(self):
for _, subscription in self.subscriptions.items():
s = subscription.to_json_object()
json_str = json.dumps(["REQ", s["id"], s["filters"][0]])
def publish_subscriptions(self, subscriptions: List[Subscription] = []):
for s in subscriptions:
json_str = json.dumps(["REQ", s.id] + s.filters)
self.publish(json_str)
async def queue_worker(self):
@ -103,55 +75,44 @@ class Relay:
message = self.queue.get(timeout=1)
self.num_sent_events += 1
self.ws.send(message)
except:
except Exception as _:
pass
else:
await asyncio.sleep(1)
if self.shutdown:
logger.warning(f"Closing queue worker for '{self.url}'.")
break
def add_subscription(self, id, filters: Filters):
with self.lock:
self.subscriptions[id] = Subscription(id, filters)
if self.shutdown:
logger.warning(f"[Relay: {self.url}] Closing queue worker.")
return
def close_subscription(self, id: str) -> None:
with self.lock:
self.subscriptions.pop(id)
try:
self.publish(json.dumps(["CLOSE", id]))
def to_json_object(self) -> dict:
return {
"url": self.url,
"policy": self.policy.to_json_object(),
"subscriptions": [
subscription.to_json_object()
for subscription in self.subscriptions.values()
],
}
except Exception as e:
logger.debug(f"[Relay: {self.url}] Failed to close subscription: {e}")
def add_notice(self, notice: str):
self.notice_list = ([notice] + self.notice_list)[:20]
self.notice_list = [notice] + self.notice_list
def _on_open(self, _):
logger.info(f"Connected to relay: '{self.url}'.")
logger.info(f"[Relay: {self.url}] Connected.")
self.connected = True
self.shutdown = False
def _on_close(self, _, status_code, message):
logger.warning(f"Connection to relay {self.url} closed. Status: '{status_code}'. Message: '{message}'.")
logger.warning(
f"[Relay: {self.url}] Connection closed."
+ f" Status: '{status_code}'. Message: '{message}'."
)
self.close()
def _on_message(self, _, message: str):
if self._is_valid_message(message):
self.num_received_events += 1
self.message_pool.add_message(message, self.url)
self.num_received_events += 1
self.message_pool.add_message(message, self.url)
def _on_error(self, _, error):
logger.warning(f"Relay error: '{str(error)}'")
logger.warning(f"[Relay: {self.url}] Error: '{str(error)}'")
self._append_error_message(str(error))
self.connected = False
self.error_counter += 1
self.close()
def _on_ping(self, *_):
return
@ -159,65 +120,7 @@ class Relay:
def _on_pong(self, *_):
return
def _is_valid_message(self, message: str) -> bool:
message = message.strip("\n")
if not message or message[0] != "[" or message[-1] != "]":
return False
message_json = json.loads(message)
message_type = message_json[0]
if not RelayMessageType.is_valid(message_type):
return False
if message_type == RelayMessageType.EVENT:
return self._is_valid_event_message(message_json)
if message_type == RelayMessageType.COMMAND_RESULT:
return self._is_valid_command_result_message(message, message_json)
return True
def _is_valid_event_message(self, message_json):
if not len(message_json) == 3:
return False
subscription_id = message_json[1]
with self.lock:
if subscription_id not in self.subscriptions:
return False
e = message_json[2]
event = Event(
e["content"],
e["pubkey"],
e["created_at"],
e["kind"],
e["tags"],
e["sig"],
)
if not event.verify():
return False
with self.lock:
subscription = self.subscriptions[subscription_id]
if subscription.filters and not subscription.filters.match(event):
return False
return True
def _is_valid_command_result_message(self, message, message_json):
if not len(message_json) < 3:
return False
if message_json[2] != True:
logger.warning(f"Relay '{self.url}' negative command result: '{message}'")
self._append_error_message(message)
return False
return True
def _append_error_message(self, message):
self.error_list = ([message] + self.error_list)[:20]
self.last_error_date = int(time.time())
self.error_counter += 1
self.error_list = [message] + self.error_list
self.last_error_date = int(time.time())

View file

@ -1,21 +1,15 @@
import asyncio
import ssl
import threading
import time
from typing import List
from loguru import logger
from .filter import Filters
from .message_pool import MessagePool, NoticeMessage
from .relay import Relay, RelayPolicy
from .relay import Relay
from .subscription import Subscription
class RelayException(Exception):
pass
class RelayManager:
def __init__(self) -> None:
self.relays: dict[str, Relay] = {}
@ -25,72 +19,97 @@ class RelayManager:
self._cached_subscriptions: dict[str, Subscription] = {}
self._subscriptions_lock = threading.Lock()
def add_relay(self, url: str, read: bool = True, write: bool = True) -> Relay:
def add_relay(self, url: str) -> Relay:
if url in list(self.relays.keys()):
return
with self._subscriptions_lock:
subscriptions = self._cached_subscriptions.copy()
logger.debug(f"Relay '{url}' already present.")
return self.relays[url]
policy = RelayPolicy(read, write)
relay = Relay(url, policy, self.message_pool, subscriptions)
relay = Relay(url, self.message_pool)
self.relays[url] = relay
self._open_connection(
relay,
{"cert_reqs": ssl.CERT_NONE}
) # NOTE: This disables ssl certificate verification
self._open_connection(relay)
relay.publish_subscriptions()
relay.publish_subscriptions(list(self._cached_subscriptions.values()))
return relay
def remove_relay(self, url: str):
self.relays[url].close()
self.relays.pop(url)
self.threads[url].join(timeout=5)
self.threads.pop(url)
self.queue_threads[url].join(timeout=5)
self.queue_threads.pop(url)
try:
self.relays[url].close()
except Exception as e:
logger.debug(e)
def add_subscription(self, id: str, filters: Filters):
if url in self.relays:
self.relays.pop(url)
try:
self.threads[url].join(timeout=5)
except Exception as e:
logger.debug(e)
if url in self.threads:
self.threads.pop(url)
try:
self.queue_threads[url].join(timeout=5)
except Exception as e:
logger.debug(e)
if url in self.queue_threads:
self.queue_threads.pop(url)
def remove_relays(self):
relay_urls = list(self.relays.keys())
for url in relay_urls:
self.remove_relay(url)
def add_subscription(self, id: str, filters: List[str]):
s = Subscription(id, filters)
with self._subscriptions_lock:
self._cached_subscriptions[id] = Subscription(id, filters)
self._cached_subscriptions[id] = s
for relay in self.relays.values():
relay.add_subscription(id, filters)
relay.publish_subscriptions([s])
def close_subscription(self, id: str):
with self._subscriptions_lock:
self._cached_subscriptions.pop(id)
try:
with self._subscriptions_lock:
if id in self._cached_subscriptions:
self._cached_subscriptions.pop(id)
for relay in self.relays.values():
relay.close_subscription(id)
for relay in self.relays.values():
relay.close_subscription(id)
except Exception as e:
logger.debug(e)
def close_subscriptions(self, subscriptions: List[str]):
for id in subscriptions:
self.close_subscription(id)
def close_all_subscriptions(self):
all_subscriptions = list(self._cached_subscriptions.keys())
self.close_subscriptions(all_subscriptions)
def check_and_restart_relays(self):
stopped_relays = [r for r in self.relays.values() if r.shutdown]
for relay in stopped_relays:
self._restart_relay(relay)
def close_connections(self):
for relay in self.relays.values():
relay.close()
def publish_message(self, message: str):
for relay in self.relays.values():
if relay.policy.should_write:
relay.publish(message)
relay.publish(message)
def handle_notice(self, notice: NoticeMessage):
relay = next((r for r in self.relays.values() if r.url == notice.url))
if relay:
relay.add_notice(notice.content)
def _open_connection(self, relay: Relay, ssl_options: dict = None, proxy: dict = None):
def _open_connection(self, relay: Relay):
self.threads[relay.url] = threading.Thread(
target=relay.connect,
args=(ssl_options, proxy),
name=f"{relay.url}-thread",
daemon=True,
)
@ -98,7 +117,7 @@ class RelayManager:
def wrap_async_queue_worker():
asyncio.run(relay.queue_worker())
self.queue_threads[relay.url] = threading.Thread(
target=wrap_async_queue_worker,
name=f"{relay.url}-queue",
@ -108,14 +127,16 @@ class RelayManager:
def _restart_relay(self, relay: Relay):
time_since_last_error = time.time() - relay.last_error_date
min_wait_time = min(60 * relay.error_counter, 60 * 60 * 24) # try at least once a day
min_wait_time = min(
60 * relay.error_counter, 60 * 60
) # try at least once an hour
if time_since_last_error < min_wait_time:
return
logger.info(f"Restarting connection to relay '{relay.url}'")
self.remove_relay(relay.url)
new_relay = self.add_relay(relay.url)
new_relay.error_counter = relay.error_counter
new_relay.error_list = relay.error_list
new_relay.error_list = relay.error_list

View file

@ -1,13 +1,7 @@
from .filter import Filters
from typing import List
class Subscription:
def __init__(self, id: str, filters: Filters=None) -> None:
def __init__(self, id: str, filters: List[str] = None) -> None:
self.id = id
self.filters = filters
def to_json_object(self):
return {
"id": self.id,
"filters": self.filters.to_json_array()
}