refactor: add_relay logic

This commit is contained in:
Vlad Stan 2023-06-22 10:02:15 +03:00
parent d08e91b2c7
commit 09d2fc0493
5 changed files with 61 additions and 61 deletions

View file

@ -1,19 +1,15 @@
from typing import * import base64
import ssl
import time
import json import json
import os import os
import base64 import time
from typing import *
from ..event import Event
from ..relay_manager import RelayManager
from ..message_type import ClientMessageType
from ..key import PrivateKey, PublicKey
from ..event import EncryptedDirectMessage, Event, EventKind
from ..filter import Filter, Filters from ..filter import Filter, Filters
from ..event import Event, EventKind, EncryptedDirectMessage from ..key import PrivateKey, PublicKey
from ..relay_manager import RelayManager
from ..message_type import ClientMessageType from ..message_type import ClientMessageType
from ..relay_manager import RelayManager
from ..subscription import Subscription
# from aes import AESCipher # from aes import AESCipher
from . import cbc from . import cbc
@ -38,12 +34,11 @@ class NostrClient:
if connect: if connect:
self.connect() self.connect()
def connect(self): async def connect(self, subscriptions: dict[str, Subscription] = {}):
for relay in self.relays: for relay in self.relays:
self.relay_manager.add_relay(relay) self.relay_manager.add_relay(relay, subscriptions)
self.relay_manager.open_connections(
{"cert_reqs": ssl.CERT_NONE}
) # NOTE: This disables ssl certificate verification
def close(self): def close(self):
self.relay_manager.close_connections() self.relay_manager.close_connections()

View file

@ -91,6 +91,12 @@ class Relay:
def publish(self, message: str): def publish(self, message: str):
self.queue.put(message) self.queue.put(message)
def publish_subscriptions(self):
for _, subscription in self.subscriptions.items():
s = subscription.to_json_object()
json_str = json.dumps(["REQ", s["id"], s["filters"][0]])
self.publish(json_str)
def queue_worker(self): def queue_worker(self):
print("#### IN !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!", self.url) print("#### IN !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!", self.url)
while True: while True:

View file

@ -1,11 +1,12 @@
import json
import ssl
import threading import threading
from .event import Event from .event import Event
from .filter import Filters from .filter import Filters
from .message_pool import MessagePool from .message_pool import MessagePool
from .message_type import ClientMessageType
from .relay import Relay, RelayPolicy from .relay import Relay, RelayPolicy
from .subscription import Subscription
class RelayException(Exception): class RelayException(Exception):
@ -20,19 +21,30 @@ class RelayManager:
self.message_pool = MessagePool() self.message_pool = MessagePool()
def add_relay( def add_relay(
self, url: str, read: bool = True, write: bool = True, subscriptions={} self, url: str, read: bool = True, write: bool = True, subscriptions: dict[str, Subscription] = {}
): ) -> Relay:
if url in self.relays: if url in self.relays:
return return
policy = RelayPolicy(read, write) policy = RelayPolicy(read, write)
relay = Relay(url, policy, self.message_pool, subscriptions.copy()) relay = Relay(url, policy, self.message_pool, subscriptions.copy())
self.relays[url] = relay self.relays[url] = relay
self.open_connection(
relay,
{"cert_reqs": ssl.CERT_NONE}
) # NOTE: This disables ssl certificate verification
relay.publish_subscriptions()
return relay
def remove_relay(self, url: str): def remove_relay(self, url: str):
self.relays[url].close()
self.relays.pop(url)
self.threads[url].join(timeout=1) self.threads[url].join(timeout=1)
self.threads.pop(url) self.threads.pop(url)
self.queue_threads[url].join(timeout=1)
self.queue_threads.pop(url)
self.relays[url].close()
self.relays.pop(url)
def add_subscription(self, id: str, filters: Filters): def add_subscription(self, id: str, filters: Filters):
for relay in self.relays.values(): for relay in self.relays.values():
@ -42,25 +54,22 @@ class RelayManager:
for relay in self.relays.values(): for relay in self.relays.values():
relay.close_subscription(id) relay.close_subscription(id)
def open_connections(self, ssl_options: dict = None, proxy: dict = None):
for relay in self.relays.values():
if relay.url not in self.threads:
self.threads[relay.url] = threading.Thread(
target=relay.connect,
args=(ssl_options, proxy),
name=f"{relay.url}-thread",
daemon=True,
)
self.threads[relay.url].start() def open_connection(self, relay: Relay, ssl_options: dict = None, proxy: dict = None):
self.threads[relay.url] = threading.Thread(
target=relay.connect,
args=(ssl_options, proxy),
name=f"{relay.url}-thread",
daemon=True,
)
self.threads[relay.url].start()
if relay.url not in self.queue_threads: self.queue_threads[relay.url] = threading.Thread(
self.queue_threads[relay.url] = threading.Thread( target=relay.queue_worker,
target=relay.queue_worker, name=f"{relay.url}-queue",
name=f"{relay.url}-queue", daemon=True,
daemon=True, )
) self.queue_threads[relay.url].start()
self.queue_threads[relay.url].start()
def close_connections(self): def close_connections(self):
for relay in self.relays.values(): for relay in self.relays.values():

View file

@ -17,31 +17,13 @@ from .services import (
async def init_relays(): async def init_relays():
# we save any subscriptions teporarily to re-add them after reinitializing the client
subscriptions = {}
for relay in nostr.client.relay_manager.relays.values():
# relay.add_subscription(id, filters)
for subscription_id, filters in relay.subscriptions.items():
subscriptions[subscription_id] = filters
# reinitialize the entire client # reinitialize the entire client
nostr.__init__() nostr.__init__()
# get relays from db # get relays from db
relays = await get_relays() relays = await get_relays()
# set relays and connect to them # set relays and connect to them
nostr.client.relays = list(set([r.url for r in relays.__root__ if r.url])) nostr.client.relays = list(set([r.url for r in relays.__root__ if r.url]))
nostr.client.connect() await nostr.client.connect()
await asyncio.sleep(2)
# re-add subscriptions
for subscription_id, subscription in subscriptions.items():
nostr.client.relay_manager.add_subscription(
subscription_id, subscription.filters
)
s = subscription.to_json_object()
json_str = json.dumps(["REQ", s["id"], s["filters"][0]])
nostr.client.relay_manager.publish_message(json_str)
return
async def subscribe_events(): async def subscribe_events():

View file

@ -1,7 +1,7 @@
import asyncio import asyncio
import json import json
from http import HTTPStatus from http import HTTPStatus
from typing import Optional from typing import List, Optional
from fastapi import Depends, WebSocket from fastapi import Depends, WebSocket
from loguru import logger from loguru import logger
@ -15,6 +15,7 @@ from .crud import add_relay, delete_relay, get_relays
from .helpers import normalize_public_key from .helpers import normalize_public_key
from .models import Relay, RelayList, TestMessage, TestMessageResponse from .models import Relay, RelayList, TestMessage, TestMessageResponse
from .nostr.key import EncryptedDirectMessage, PrivateKey from .nostr.key import EncryptedDirectMessage, PrivateKey
from .nostr.relay import Relay as NostrRelay
from .services import NostrRouter, nostr from .services import NostrRouter, nostr
from .tasks import init_relays from .tasks import init_relays
@ -60,8 +61,15 @@ async def api_add_relay(relay: Relay) -> Optional[RelayList]:
) )
relay.id = urlsafe_short_hash() relay.id = urlsafe_short_hash()
await add_relay(relay) await add_relay(relay)
# we can't add relays during runtime yet
await init_relays() all_relays: List[NostrRelay] = nostr.client.relay_manager.relays.values()
if len(all_relays):
subscriptions = all_relays[0].subscriptions
nostr.client.relays.append(relay.url)
nostr.client.relay_manager.add_relay(subscriptions)
nostr.client.relay_manager.connect_relay(relay.url)
return await get_relays() return await get_relays()