diff --git a/nostr/client/client.py b/nostr/client/client.py index 6e70f71..4d10647 100644 --- a/nostr/client/client.py +++ b/nostr/client/client.py @@ -1,19 +1,15 @@ -from typing import * -import ssl -import time +import base64 import json import os -import base64 - -from ..event import Event -from ..relay_manager import RelayManager -from ..message_type import ClientMessageType -from ..key import PrivateKey, PublicKey +import time +from typing import * +from ..event import EncryptedDirectMessage, Event, EventKind from ..filter import Filter, Filters -from ..event import Event, EventKind, EncryptedDirectMessage -from ..relay_manager import RelayManager +from ..key import PrivateKey, PublicKey from ..message_type import ClientMessageType +from ..relay_manager import RelayManager +from ..subscription import Subscription # from aes import AESCipher from . import cbc @@ -38,12 +34,11 @@ class NostrClient: if connect: self.connect() - def connect(self): + async def connect(self, subscriptions: dict[str, Subscription] = {}): for relay in self.relays: - self.relay_manager.add_relay(relay) - self.relay_manager.open_connections( - {"cert_reqs": ssl.CERT_NONE} - ) # NOTE: This disables ssl certificate verification + self.relay_manager.add_relay(relay, subscriptions) + + def close(self): self.relay_manager.close_connections() diff --git a/nostr/relay.py b/nostr/relay.py index 246b985..b6207b5 100644 --- a/nostr/relay.py +++ b/nostr/relay.py @@ -91,6 +91,12 @@ class Relay: def publish(self, message: str): self.queue.put(message) + def publish_subscriptions(self): + for _, subscription in self.subscriptions.items(): + s = subscription.to_json_object() + json_str = json.dumps(["REQ", s["id"], s["filters"][0]]) + self.publish(json_str) + def queue_worker(self): print("#### IN !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!", self.url) while True: diff --git a/nostr/relay_manager.py b/nostr/relay_manager.py index f6eba36..0ec324a 100644 --- a/nostr/relay_manager.py +++ b/nostr/relay_manager.py @@ -1,11 +1,12 @@ -import json + +import ssl import threading from .event import Event from .filter import Filters from .message_pool import MessagePool -from .message_type import ClientMessageType from .relay import Relay, RelayPolicy +from .subscription import Subscription class RelayException(Exception): @@ -20,19 +21,30 @@ class RelayManager: self.message_pool = MessagePool() def add_relay( - self, url: str, read: bool = True, write: bool = True, subscriptions={} - ): + self, url: str, read: bool = True, write: bool = True, subscriptions: dict[str, Subscription] = {} + ) -> Relay: if url in self.relays: return + policy = RelayPolicy(read, write) relay = Relay(url, policy, self.message_pool, subscriptions.copy()) self.relays[url] = relay + self.open_connection( + relay, + {"cert_reqs": ssl.CERT_NONE} + ) # NOTE: This disables ssl certificate verification + + relay.publish_subscriptions() + return relay + def remove_relay(self, url: str): - self.relays[url].close() - self.relays.pop(url) self.threads[url].join(timeout=1) self.threads.pop(url) + self.queue_threads[url].join(timeout=1) + self.queue_threads.pop(url) + self.relays[url].close() + self.relays.pop(url) def add_subscription(self, id: str, filters: Filters): for relay in self.relays.values(): @@ -42,25 +54,22 @@ class RelayManager: for relay in self.relays.values(): relay.close_subscription(id) - def open_connections(self, ssl_options: dict = None, proxy: dict = None): - for relay in self.relays.values(): - if relay.url not in self.threads: - self.threads[relay.url] = threading.Thread( - target=relay.connect, - args=(ssl_options, proxy), - name=f"{relay.url}-thread", - daemon=True, - ) - self.threads[relay.url].start() + def open_connection(self, relay: Relay, ssl_options: dict = None, proxy: dict = None): + self.threads[relay.url] = threading.Thread( + target=relay.connect, + args=(ssl_options, proxy), + name=f"{relay.url}-thread", + daemon=True, + ) + self.threads[relay.url].start() - if relay.url not in self.queue_threads: - self.queue_threads[relay.url] = threading.Thread( - target=relay.queue_worker, - name=f"{relay.url}-queue", - daemon=True, - ) - self.queue_threads[relay.url].start() + self.queue_threads[relay.url] = threading.Thread( + target=relay.queue_worker, + name=f"{relay.url}-queue", + daemon=True, + ) + self.queue_threads[relay.url].start() def close_connections(self): for relay in self.relays.values(): diff --git a/tasks.py b/tasks.py index beff9db..eb5391a 100644 --- a/tasks.py +++ b/tasks.py @@ -17,31 +17,13 @@ from .services import ( async def init_relays(): - # we save any subscriptions teporarily to re-add them after reinitializing the client - subscriptions = {} - for relay in nostr.client.relay_manager.relays.values(): - # relay.add_subscription(id, filters) - for subscription_id, filters in relay.subscriptions.items(): - subscriptions[subscription_id] = filters - # reinitialize the entire client nostr.__init__() # get relays from db relays = await get_relays() # set relays and connect to them nostr.client.relays = list(set([r.url for r in relays.__root__ if r.url])) - nostr.client.connect() - - await asyncio.sleep(2) - # re-add subscriptions - for subscription_id, subscription in subscriptions.items(): - nostr.client.relay_manager.add_subscription( - subscription_id, subscription.filters - ) - s = subscription.to_json_object() - json_str = json.dumps(["REQ", s["id"], s["filters"][0]]) - nostr.client.relay_manager.publish_message(json_str) - return + await nostr.client.connect() async def subscribe_events(): diff --git a/views_api.py b/views_api.py index 12f4f79..88193ad 100644 --- a/views_api.py +++ b/views_api.py @@ -1,7 +1,7 @@ import asyncio import json from http import HTTPStatus -from typing import Optional +from typing import List, Optional from fastapi import Depends, WebSocket from loguru import logger @@ -15,6 +15,7 @@ from .crud import add_relay, delete_relay, get_relays from .helpers import normalize_public_key from .models import Relay, RelayList, TestMessage, TestMessageResponse from .nostr.key import EncryptedDirectMessage, PrivateKey +from .nostr.relay import Relay as NostrRelay from .services import NostrRouter, nostr from .tasks import init_relays @@ -60,8 +61,15 @@ async def api_add_relay(relay: Relay) -> Optional[RelayList]: ) relay.id = urlsafe_short_hash() await add_relay(relay) - # we can't add relays during runtime yet - await init_relays() + + all_relays: List[NostrRelay] = nostr.client.relay_manager.relays.values() + if len(all_relays): + subscriptions = all_relays[0].subscriptions + nostr.client.relays.append(relay.url) + nostr.client.relay_manager.add_relay(subscriptions) + + nostr.client.relay_manager.connect_relay(relay.url) + return await get_relays()