feat: listen for direct messages

This commit is contained in:
Vlad Stan 2023-03-03 18:24:53 +02:00
parent 9a82577493
commit cec7d2ee25
6 changed files with 142 additions and 11 deletions

View file

@ -1,4 +1,5 @@
import asyncio import asyncio
from asyncio import Task
from typing import List from typing import List
from fastapi import APIRouter from fastapi import APIRouter
@ -25,9 +26,9 @@ def nostrmarket_renderer():
return template_renderer(["lnbits/extensions/nostrmarket/templates"]) return template_renderer(["lnbits/extensions/nostrmarket/templates"])
scheduled_tasks: List[asyncio.Task] = [] scheduled_tasks: List[Task] = []
from .tasks import subscribe_nostrclient_ws, wait_for_paid_invoices from .tasks import subscribe_nostrclient, wait_for_nostr_events, wait_for_paid_invoices
from .views import * # noqa from .views import * # noqa
from .views_api import * # noqa from .views_api import * # noqa
@ -35,5 +36,6 @@ from .views_api import * # noqa
def nostrmarket_start(): def nostrmarket_start():
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
task1 = loop.create_task(catch_everything_and_restart(wait_for_paid_invoices)) task1 = loop.create_task(catch_everything_and_restart(wait_for_paid_invoices))
task2 = loop.create_task(catch_everything_and_restart(subscribe_nostrclient_ws)) task2 = loop.create_task(catch_everything_and_restart(subscribe_nostrclient))
scheduled_tasks.append([task1, task2]) task3 = loop.create_task(catch_everything_and_restart(wait_for_nostr_events))
scheduled_tasks.append([task1, task2, task3])

17
crud.py
View file

@ -45,6 +45,23 @@ async def get_merchant(user_id: str, merchant_id: str) -> Optional[Merchant]:
return Merchant.from_row(row) if row else None return Merchant.from_row(row) if row else None
async def get_merchant_by_pubkey(public_key: str) -> Optional[Merchant]:
row = await db.fetchone(
"""SELECT * FROM nostrmarket.merchants WHERE public_key = ? """,
(public_key,),
)
return Merchant.from_row(row) if row else None
async def get_public_keys_for_merchants() -> List[str]:
rows = await db.fetchall(
"""SELECT public_key FROM nostrmarket.merchants""",
)
return [row[0] for row in rows]
async def get_merchant_for_user(user_id: str) -> Optional[Merchant]: async def get_merchant_for_user(user_id: str) -> Optional[Merchant]:
row = await db.fetchone( row = await db.fetchone(
"""SELECT * FROM nostrmarket.merchants WHERE user_id = ? """, """SELECT * FROM nostrmarket.merchants WHERE user_id = ? """,

View file

@ -6,7 +6,7 @@ from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel
from .helpers import sign_message_hash from .helpers import decrypt_message, get_shared_secret, sign_message_hash
from .nostr.event import NostrEvent from .nostr.event import NostrEvent
######################################## NOSTR ######################################## ######################################## NOSTR ########################################
@ -39,6 +39,10 @@ class Merchant(PartialMerchant):
def sign_hash(self, hash: bytes) -> str: def sign_hash(self, hash: bytes) -> str:
return sign_message_hash(self.private_key, hash) return sign_message_hash(self.private_key, hash)
def decrypt_message(self, encrypted_message: str, public_key: str) -> str:
encryption_key = get_shared_secret(self.private_key, public_key)
return decrypt_message(encrypted_message, encryption_key)
@classmethod @classmethod
def from_row(cls, row: Row) -> "Merchant": def from_row(cls, row: Row) -> "Merchant":
merchant = cls(**dict(row)) merchant = cls(**dict(row))

View file

@ -1,5 +1,9 @@
from threading import Thread
from typing import Callable
import httpx import httpx
from loguru import logger from loguru import logger
from websocket import WebSocketApp
from lnbits.app import settings from lnbits.app import settings
from lnbits.helpers import url_for from lnbits.helpers import url_for
@ -10,7 +14,7 @@ from .event import NostrEvent
async def publish_nostr_event(e: NostrEvent): async def publish_nostr_event(e: NostrEvent):
url = url_for("/nostrclient/api/v1/publish", external=True) url = url_for("/nostrclient/api/v1/publish", external=True)
data = dict(e) data = dict(e)
# print("### published", dict(data)) print("### published", dict(data))
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
try: try:
await client.post( await client.post(
@ -19,3 +23,44 @@ async def publish_nostr_event(e: NostrEvent):
) )
except Exception as ex: except Exception as ex:
logger.warning(ex) logger.warning(ex)
async def connect_to_nostrclient_ws(
on_open: Callable, on_message: Callable
) -> WebSocketApp:
def on_error(_, error):
logger.warning(error)
logger.debug(f"Subscribing to websockets for nostrclient extension")
ws = WebSocketApp(
f"ws://localhost:{settings.port}/nostrclient/api/v1/filters",
on_message=on_message,
on_open=on_open,
on_error=on_error,
)
wst = Thread(target=ws.run_forever)
wst.daemon = True
wst.start()
return ws
async def handle_event(event, pubkeys):
tags = [t[1] for t in event["tags"] if t[0] == "p"]
to_merchant = None
if tags and len(tags) > 0:
to_merchant = tags[0]
if event["pubkey"] in pubkeys or to_merchant in pubkeys:
logger.debug(f"Event sent to {to_merchant}")
pubkey = to_merchant if to_merchant in pubkeys else event["pubkey"]
# Send event to market extension
await send_event_to_market(event=event, pubkey=pubkey)
async def send_event_to_market(event: dict, pubkey: str):
# Sends event to market extension, for decrypt and handling
market_url = url_for(f"/market/api/v1/nip04/{pubkey}", external=True)
async with httpx.AsyncClient() as client:
await client.post(url=market_url, json=event)

View file

@ -1,18 +1,25 @@
import asyncio import asyncio
import json import json
import threading from asyncio import Queue
import httpx import httpx
import websocket import websocket
from loguru import logger from loguru import logger
from websocket import WebSocketApp
from lnbits.core.models import Payment from lnbits.core.models import Payment
from lnbits.helpers import url_for
from lnbits.tasks import register_invoice_listener from lnbits.tasks import register_invoice_listener
from .crud import get_merchant, get_merchant_by_pubkey, get_public_keys_for_merchants
from .nostr.event import NostrEvent
from .nostr.nostr_client import connect_to_nostrclient_ws
recieve_event_queue: Queue = Queue()
send_req_queue: Queue = Queue()
async def wait_for_paid_invoices(): async def wait_for_paid_invoices():
invoice_queue = asyncio.Queue() invoice_queue = Queue()
register_invoice_listener(invoice_queue) register_invoice_listener(invoice_queue)
while True: while True:
@ -27,5 +34,61 @@ async def on_invoice_paid(payment: Payment) -> None:
print("### on_invoice_paid") print("### on_invoice_paid")
async def subscribe_nostrclient_ws(): async def subscribe_nostrclient():
print("### subscribe_nostrclient_ws") print("### subscribe_nostrclient_ws")
def on_open(_):
logger.info("Connected to 'nostrclient' websocket")
def on_message(_, message):
print("### on_message", message)
recieve_event_queue.put_nowait(message)
# wait for 'nostrclient' extension to initialize
await asyncio.sleep(5)
ws: WebSocketApp = None
while True:
try:
req = None
if not ws:
ws = await connect_to_nostrclient_ws(on_open, on_message)
# be sure the connection is open
await asyncio.sleep(3)
req = await send_req_queue.get()
print("### req", req)
ws.send(json.dumps(req))
except Exception as ex:
logger.warning(ex)
if req:
await send_req_queue.put(req)
ws = None # todo close
await asyncio.sleep(5)
async def wait_for_nostr_events():
public_keys = await get_public_keys_for_merchants()
for p in public_keys:
await send_req_queue.put(
["REQ", f"direct-messages:{p}", {"kind": 4, "#p": [p]}]
)
while True:
message = await recieve_event_queue.get()
await handle_message(message)
async def handle_message(msg: str):
try:
type, subscription_id, event = json.loads(msg)
_, public_key = subscription_id.split(":")
if type.upper() == "EVENT":
event = NostrEvent(**event)
if event.kind == 4:
merchant = await get_merchant_by_pubkey(public_key)
if not merchant:
return
clear_text_msg = merchant.decrypt_message(event.content, event.pubkey)
print("### clear_text_msg", clear_text_msg)
except Exception as ex:
logger.warning(ex)

View file

@ -12,7 +12,6 @@ from lnbits.decorators import (
require_admin_key, require_admin_key,
require_invoice_key, require_invoice_key,
) )
from lnbits.extensions.nostrmarket.nostr.event import NostrEvent
from lnbits.utils.exchange_rates import currencies from lnbits.utils.exchange_rates import currencies
from . import nostrmarket_ext from . import nostrmarket_ext
@ -46,6 +45,7 @@ from .models import (
Stall, Stall,
Zone, Zone,
) )
from .nostr.event import NostrEvent
from .nostr.nostr_client import publish_nostr_event from .nostr.nostr_client import publish_nostr_event
######################################## MERCHANT ######################################## ######################################## MERCHANT ########################################