diff --git a/crud.py b/crud.py index ad0de8e..339ccfe 100644 --- a/crud.py +++ b/crud.py @@ -1,5 +1,4 @@ import json -import time from typing import List, Optional from lnbits.helpers import urlsafe_short_hash @@ -51,9 +50,8 @@ async def update_merchant( ) return await get_merchant(user_id, merchant_id) -async def touch_merchant( - user_id: str, merchant_id: str -) -> Optional[Merchant]: + +async def touch_merchant(user_id: str, merchant_id: str) -> Optional[Merchant]: await db.execute( f""" UPDATE nostrmarket.merchants SET time = {db.timestamp_now} @@ -225,20 +223,25 @@ async def get_stall(merchant_id: str, stall_id: str) -> Optional[Stall]: async def get_stalls(merchant_id: str, pending: Optional[bool] = False) -> List[Stall]: rows = await db.fetchall( "SELECT * FROM nostrmarket.stalls WHERE merchant_id = ? AND pending = ?", - (merchant_id, pending,), + ( + merchant_id, + pending, + ), ) return [Stall.from_row(row) for row in rows] -async def get_last_stall_update_time(merchant_id: str) -> int: + +async def get_last_stall_update_time() -> int: row = await db.fetchone( """ SELECT event_created_at FROM nostrmarket.stalls - WHERE merchant_id = ? ORDER BY event_created_at DESC LIMIT 1 + ORDER BY event_created_at DESC LIMIT 1 """, - (merchant_id,), + (), ) return row[0] or 0 if row else 0 + async def update_stall(merchant_id: str, stall: Stall) -> Optional[Stall]: await db.execute( f""" @@ -257,7 +260,7 @@ async def update_stall(merchant_id: str, stall: Stall) -> Optional[Stall]: ), # todo: cost is float. should be int for sats json.dumps(stall.config.dict()), merchant_id, - stall.id + stall.id, ), ) return await get_stall(merchant_id, stall.id) @@ -366,7 +369,9 @@ async def get_product(merchant_id: str, product_id: str) -> Optional[Product]: return Product.from_row(row) if row else None -async def get_products(merchant_id: str, stall_id: str, pending: Optional[bool] = False) -> List[Product]: +async def get_products( + merchant_id: str, stall_id: str, pending: Optional[bool] = False +) -> List[Product]: rows = await db.fetchall( "SELECT * FROM nostrmarket.products WHERE merchant_id = ? AND stall_id = ? AND pending = ?", (merchant_id, stall_id, pending), @@ -401,16 +406,18 @@ async def get_wallet_for_product(product_id: str) -> Optional[str]: ) return row[0] if row else None -async def get_last_product_update_time(merchant_id: str) -> int: + +async def get_last_product_update_time() -> int: row = await db.fetchone( """ SELECT event_created_at FROM nostrmarket.products - WHERE merchant_id = ? ORDER BY event_created_at DESC LIMIT 1 + ORDER BY event_created_at DESC LIMIT 1 """, - (merchant_id,), + (), ) return row[0] or 0 if row else 0 + async def delete_product(merchant_id: str, product_id: str) -> None: await db.execute( "DELETE FROM nostrmarket.products WHERE merchant_id =? AND id = ?", @@ -530,24 +537,13 @@ async def get_orders_for_stall( return [Order.from_row(row) for row in rows] -async def get_last_order_time(merchant_id: str) -> int: - row = await db.fetchone( - """ - SELECT event_created_at FROM nostrmarket.orders - WHERE merchant_id = ? ORDER BY event_created_at DESC LIMIT 1 - """, - (merchant_id,), - ) - return row[0] if row else 0 - - async def update_order(merchant_id: str, order_id: str, **kwargs) -> Optional[Order]: q = ", ".join([f"{field[0]} = ?" for field in kwargs.items()]) await db.execute( f""" UPDATE nostrmarket.orders SET {q} WHERE merchant_id = ? and id = ? - """, - (*kwargs.values(), merchant_id, order_id) + """, + (*kwargs.values(), merchant_id, order_id), ) return await get_order(merchant_id, order_id) @@ -650,6 +646,7 @@ async def get_direct_messages(merchant_id: str, public_key: str) -> List[DirectM ) return [DirectMessage.from_row(row) for row in rows] + async def get_orders_from_direct_messages(merchant_id: str) -> List[DirectMessage]: rows = await db.fetchall( "SELECT * FROM nostrmarket.direct_messages WHERE merchant_id = ? AND type >= 0 ORDER BY event_created_at, type", @@ -669,19 +666,17 @@ async def get_last_direct_messages_time(merchant_id: str) -> int: return row[0] if row else 0 - -async def get_last_direct_messages_created_at(merchant_id: str) -> int: +async def get_last_direct_messages_created_at() -> int: row = await db.fetchone( """ SELECT event_created_at FROM nostrmarket.direct_messages - WHERE merchant_id = ? ORDER BY event_created_at DESC LIMIT 1 + ORDER BY event_created_at DESC LIMIT 1 """, - (merchant_id,), + (), ) return row[0] if row else 0 - async def delete_merchant_direct_messages(merchant_id: str) -> None: await db.execute( "DELETE FROM nostrmarket.direct_messages WHERE merchant_id = ?", @@ -750,12 +745,19 @@ async def update_customer_profile( async def increment_customer_unread_messages(merchant_id: str, public_key: str): await db.execute( f"UPDATE nostrmarket.customers SET unread_messages = unread_messages + 1 WHERE merchant_id = ? AND public_key = ?", - (merchant_id, public_key,), + ( + merchant_id, + public_key, + ), ) -#??? two merchants + +# ??? two merchants async def update_customer_no_unread_messages(merchant_id: str, public_key: str): await db.execute( f"UPDATE nostrmarket.customers SET unread_messages = 0 WHERE merchant_id =? AND public_key = ?", - (merchant_id, public_key,), + ( + merchant_id, + public_key, + ), ) diff --git a/nostr/nostr_client.py b/nostr/nostr_client.py index 8e8cd2e..20e9bd5 100644 --- a/nostr/nostr_client.py +++ b/nostr/nostr_client.py @@ -8,6 +8,7 @@ from loguru import logger from websocket import WebSocketApp from lnbits.app import settings +from lnbits.helpers import urlsafe_short_hash from .event import NostrEvent @@ -17,7 +18,7 @@ class NostrClient: self.recieve_event_queue: Queue = Queue() self.send_req_queue: Queue = Queue() self.ws: WebSocketApp = None - + self.subscription_id = "nostrmarket-" + urlsafe_short_hash()[:32] async def connect_to_nostrclient_ws( self, on_open: Callable, on_message: Callable @@ -62,6 +63,7 @@ class NostrClient: # be sure the connection is open await asyncio.sleep(3) req = await self.send_req_queue.get() + if isinstance(req, ValueError): running = False logger.warning(str(req)) @@ -77,68 +79,101 @@ class NostrClient: async def publish_nostr_event(self, e: NostrEvent): await self.send_req_queue.put(["EVENT", e.dict()]) - async def subscribe_merchant(self, public_key: str, since = 0): - dm_filters = self._filters_for_direct_messages(public_key, since) - stall_filters = self._filters_for_stall_events(public_key, since) - product_filters = self._filters_for_product_events(public_key, since) - profile_filters = self._filters_for_user_profile(public_key, since) + async def subscribe_merchants( + self, + public_keys: List[str], + dm_time=0, + stall_time=0, + product_time=0, + profile_time=0, + ): + dm_filters = self._filters_for_direct_messages(public_keys, dm_time) + stall_filters = self._filters_for_stall_events(public_keys, stall_time) + product_filters = self._filters_for_product_events(public_keys, product_time) + profile_filters = self._filters_for_user_profile(public_keys, profile_time) - merchant_filters = dm_filters + stall_filters + product_filters + profile_filters - - await self.send_req_queue.put( - ["REQ", f"merchant:{public_key}"] + merchant_filters + merchant_filters = ( + dm_filters + stall_filters + product_filters + profile_filters ) - logger.debug(f"Subscribed to events for: '{public_key}'.") - + self.subscription_id = "nostrmarket-" + urlsafe_short_hash()[:32] + await self.send_req_queue.put(["REQ", self.subscription_id] + merchant_filters) - def _filters_for_direct_messages(self, public_key: str, since: int) -> List: - in_messages_filter = {"kinds": [4], "#p": [public_key]} - out_messages_filter = {"kinds": [4], "authors": [public_key]} + logger.debug( + f"Subscribed to events for: {len(public_keys)} keys. New subscription id: {self.subscription_id}" + ) + + async def merchant_temp_subscription(self, pk, duration=5): + dm_filters = self._filters_for_direct_messages([pk], 0) + stall_filters = self._filters_for_stall_events([pk], 0) + product_filters = self._filters_for_product_events([pk], 0) + profile_filters = self._filters_for_user_profile([pk], 0) + + merchant_filters = ( + dm_filters + stall_filters + product_filters + profile_filters + ) + + subscription_id = "merchant-" + urlsafe_short_hash()[:32] + logger.debug( + f"New merchant temp subscription ({duration} sec). Subscription id: {subscription_id}" + ) + await self.send_req_queue.put(["REQ", subscription_id] + merchant_filters) + + async def unsubscribe_with_delay(sub_id, d): + await asyncio.sleep(d) + await self.unsubscribe(sub_id) + + asyncio.create_task(unsubscribe_with_delay(subscription_id, duration)) + + async def user_profile_temp_subscribe(self, public_key: str, duration=30) -> List: + try: + profile_filter = [{"kinds": [0], "authors": [public_key]}] + subscription_id = "profile-" + urlsafe_short_hash()[:32] + logger.debug( + f"New user temp subscription ({duration} sec). Subscription id: {subscription_id}" + ) + await self.send_req_queue.put(["REQ", subscription_id] + profile_filter) + + async def unsubscribe_with_delay(sub_id, d): + await asyncio.sleep(d) + await self.unsubscribe(sub_id) + + asyncio.create_task(unsubscribe_with_delay(subscription_id, duration)) + except Exception as ex: + logger.debug(ex) + + def _filters_for_direct_messages(self, public_keys: List[str], since: int) -> List: + in_messages_filter = {"kinds": [4], "#p": public_keys} + out_messages_filter = {"kinds": [4], "authors": public_keys} if since and since != 0: in_messages_filter["since"] = since out_messages_filter["since"] = since return [in_messages_filter, out_messages_filter] - def _filters_for_stall_events(self, public_key: str, since: int) -> List: - stall_filter = {"kinds": [30017], "authors": [public_key]} + def _filters_for_stall_events(self, public_keys: List[str], since: int) -> List: + stall_filter = {"kinds": [30017], "authors": public_keys} if since and since != 0: stall_filter["since"] = since return [stall_filter] - def _filters_for_product_events(self, public_key: str, since: int) -> List: - product_filter = {"kinds": [30018], "authors": [public_key]} + def _filters_for_product_events(self, public_keys: List[str], since: int) -> List: + product_filter = {"kinds": [30018], "authors": public_keys} if since and since != 0: product_filter["since"] = since return [product_filter] - - def _filters_for_user_profile(self, public_key: str, since: int) -> List: - profile_filter = {"kinds": [0], "authors": [public_key]} + def _filters_for_user_profile(self, public_keys: List[str], since: int) -> List: + profile_filter = {"kinds": [0], "authors": public_keys} if since and since != 0: profile_filter["since"] = since return [profile_filter] - - def subscribe_to_user_profile(self, public_key: str, since: int) -> List: - profile_filter = {"kinds": [0], "authors": [public_key]} - if since and since != 0: - profile_filter["since"] = since - - # Disabled for now. The number of clients can grow large. - # Some relays only allow a small number of subscriptions. - # There is the risk that more important subscriptions will be blocked. - # await self.send_req_queue.put( - # ["REQ", f"user-profile-events:{public_key}", profile_filter] - # ) - - - async def restart(self, public_keys: List[str]): - await self.unsubscribe_merchants(public_keys) + async def restart(self): + await self.unsubscribe_merchants() # Give some time for the CLOSE events to propagate before restarting await asyncio.sleep(10) @@ -149,24 +184,20 @@ class NostrClient: self.ws.close() self.ws = None + async def stop(self): + await self.unsubscribe_merchants() - async def stop(self, public_keys: List[str]): - await self.unsubscribe_merchants(public_keys) - - # Give some time for the CLOSE events to propagate before closing the connection - await asyncio.sleep(10) + # Give some time for the CLOSE events to propagate before closing the connection + await asyncio.sleep(10) self.ws.close() self.ws = None - async def unsubscribe_merchant(self, public_key: str): - await self.send_req_queue.put(["CLOSE", public_key]) + async def unsubscribe_merchants(self): + await self.send_req_queue.put(["CLOSE", self.subscription_id]) + logger.debug( + f"Unsubscribed from all merchants events. Subscription id: {self.subscription_id}" + ) - logger.debug(f"Unsubscribed from merchant events: '{public_key}'.") - - async def unsubscribe_merchants(self, public_keys: List[str]): - for pk in public_keys: - try: - await self.unsubscribe_merchant(pk) - except Exception as ex: - logger.warning(ex) - + async def unsubscribe(self, subscription_id): + await self.send_req_queue.put(["CLOSE", subscription_id]) + logger.debug(f"Unsubscribed from subscription id: {subscription_id}") diff --git a/services.py b/services.py index 1b484de..dcdb057 100644 --- a/services.py +++ b/services.py @@ -17,10 +17,10 @@ from .crud import ( create_stall, get_customer, get_last_direct_messages_created_at, - get_last_order_time, get_last_product_update_time, get_last_stall_update_time, get_merchant_by_pubkey, + get_merchants_ids_with_pubkeys, get_order, get_order_by_event_id, get_products, @@ -205,9 +205,14 @@ async def notify_client_of_order_status( else: dm_content = f"Order cannot be fulfilled. Reason: {message}" - dm_type = DirectMessageType.ORDER_PAID_OR_SHIPPED.value if success else DirectMessageType.PLAIN_TEXT.value + dm_type = ( + DirectMessageType.ORDER_PAID_OR_SHIPPED.value + if success + else DirectMessageType.PLAIN_TEXT.value + ) await send_dm(merchant, order.public_key, dm_type, dm_content) + async def update_products_for_order( merchant: Merchant, order: Order ) -> Tuple[bool, str]: @@ -232,15 +237,22 @@ async def autoreply_for_products_in_order( product_ids = [i.product_id for i in order.items] products = await get_products_by_ids(merchant.id, product_ids) - products_with_autoreply = [p for p in products if p.config.use_autoreply] + products_with_autoreply = [p for p in products if p.config.use_autoreply] for p in products_with_autoreply: dm_content = p.config.autoreply_message - await send_dm(merchant, order.public_key, DirectMessageType.PLAIN_TEXT.value, dm_content) - await asyncio.sleep(1) # do not send all autoreplies at once + await send_dm( + merchant, order.public_key, DirectMessageType.PLAIN_TEXT.value, dm_content + ) + await asyncio.sleep(1) # do not send all autoreplies at once -async def send_dm(merchant: Merchant, other_pubkey: str, type: str, dm_content: str,): +async def send_dm( + merchant: Merchant, + other_pubkey: str, + type: str, + dm_content: str, +): dm_event = merchant.build_dm_event(dm_content, other_pubkey) dm = PartialDirectMessage( @@ -248,23 +260,24 @@ async def send_dm(merchant: Merchant, other_pubkey: str, type: str, dm_content: event_created_at=dm_event.created_at, message=dm_content, public_key=other_pubkey, - type=type + type=type, ) dm_reply = await create_direct_message(merchant.id, dm) await nostr_client.publish_nostr_event(dm_event) await websocketUpdater( - merchant.id, - json.dumps( - { - "type": f"dm:{dm.type}", - "customerPubkey": other_pubkey, - "dm": dm_reply.dict(), - } - ), + merchant.id, + json.dumps( + { + "type": f"dm:{dm.type}", + "customerPubkey": other_pubkey, + "dm": dm_reply.dict(), + } + ), ) + async def compute_products_new_quantity( merchant_id: str, product_ids: List[str], items: List[OrderItem] ) -> Tuple[bool, List[Product], str]: @@ -293,13 +306,12 @@ async def process_nostr_message(msg: str): type, *rest = json.loads(msg) if type.upper() == "EVENT": - subscription_id, event = rest + _, event = rest event = NostrEvent(**event) if event.kind == 0: await _handle_customer_profile_update(event) elif event.kind == 4: - _, merchant_public_key = subscription_id.split(":") - await _handle_nip04_message(merchant_public_key, event) + await _handle_nip04_message(event) elif event.kind == 30017: await _handle_stall(event) elif event.kind == 30018: @@ -385,16 +397,19 @@ async def extract_customer_order_from_dm( return order -async def get_last_event_date_for_merchant(id) -> int: - last_order_time = await get_last_order_time(id) - last_dm_time = await get_last_direct_messages_created_at(id) - last_stall_update = await get_last_stall_update_time(id) - last_product_update = await get_last_product_update_time(id) - return max(last_order_time, last_dm_time, last_stall_update, last_product_update) - - -async def _handle_nip04_message(merchant_public_key: str, event: NostrEvent): +async def _handle_nip04_message(event: NostrEvent): + merchant_public_key = event.pubkey merchant = await get_merchant_by_pubkey(merchant_public_key) + + if not merchant: + p_tags = event.tag_values("p") + merchant_public_key = p_tags[0] if len(p_tags) else None + merchant = ( + await get_merchant_by_pubkey(merchant_public_key) + if merchant_public_key + else None + ) + assert merchant, f"Merchant not found for public key '{merchant_public_key}'" if event.pubkey == merchant_public_key: @@ -550,6 +565,7 @@ async def _handle_new_order( payment_req = await create_new_order(merchant_public_key, partial_order) except Exception as e: + logger.debug(e) payment_req = await create_new_failed_order( merchant_id, merchant_public_key, dm, json_data, str(e) ) @@ -575,12 +591,28 @@ async def create_new_failed_order( await create_order(merchant_id, order) return PaymentRequest(id=order.id, message=fail_message, payment_options=[]) +async def resubscribe_to_all_merchants(): + await nostr_client.unsubscribe_merchants() + # give some time for the message to propagate + asyncio.sleep(1) + await subscribe_to_all_merchants() -async def _handle_new_customer(event, merchant: Merchant): +async def subscribe_to_all_merchants(): + ids = await get_merchants_ids_with_pubkeys() + public_keys = [pk for _, pk in ids] + + last_dm_time = await get_last_direct_messages_created_at() + last_stall_time = await get_last_stall_update_time() + last_prod_time = await get_last_product_update_time() + + await nostr_client.subscribe_merchants(public_keys, last_dm_time, last_stall_time, last_prod_time, 0) + + +async def _handle_new_customer(event: NostrEvent, merchant: Merchant): await create_customer( merchant.id, Customer(merchant_id=merchant.id, public_key=event.pubkey) ) - await nostr_client.subscribe_to_user_profile(event.pubkey, 0) + await nostr_client.user_profile_temp_subscribe(event.pubkey) async def _handle_customer_profile_update(event: NostrEvent): diff --git a/tasks.py b/tasks.py index e02f781..0813936 100644 --- a/tasks.py +++ b/tasks.py @@ -1,19 +1,14 @@ from asyncio import Queue -import asyncio from lnbits.core.models import Payment from lnbits.tasks import register_invoice_listener -from .crud import ( - get_all_unique_customers, - get_last_direct_messages_created_at, - get_last_order_time, - get_last_product_update_time, - get_last_stall_update_time, - get_merchants_ids_with_pubkeys, -) from .nostr.nostr_client import NostrClient -from .services import get_last_event_date_for_merchant, handle_order_paid, process_nostr_message +from .services import ( + handle_order_paid, + process_nostr_message, + subscribe_to_all_merchants, +) async def wait_for_paid_invoices(): @@ -38,16 +33,9 @@ async def on_invoice_paid(payment: Payment) -> None: async def wait_for_nostr_events(nostr_client: NostrClient): - merchant_ids = await get_merchants_ids_with_pubkeys() - for id, pk in merchant_ids: - since = await get_last_event_date_for_merchant(id) - await nostr_client.subscribe_merchant(pk, since + 1) - await asyncio.sleep(0.1) # try to avoid 'too many concurrent REQ' from relays - # customers = await get_all_unique_customers() - # for c in customers: - # await nostr_client.subscribe_to_user_profile(c.public_key, c.event_created_at) + await subscribe_to_all_merchants() while True: message = await nostr_client.get_event() - await process_nostr_message(message) + await process_nostr_message(message) \ No newline at end of file diff --git a/views_api.py b/views_api.py index a1cde12..457f659 100644 --- a/views_api.py +++ b/views_api.py @@ -40,7 +40,6 @@ from .crud import ( get_last_direct_messages_time, get_merchant_by_pubkey, get_merchant_for_user, - get_merchants_ids_with_pubkeys, get_order, get_order_by_event_id, get_orders, @@ -86,7 +85,9 @@ from .services import ( reply_to_structured_dm, build_order_with_payment, create_or_update_order_from_dm, + resubscribe_to_all_merchants, sign_and_send_to_nostr, + subscribe_to_all_merchants, update_merchant_to_nostr, ) @@ -119,7 +120,9 @@ async def api_create_merchant( ), ) - await nostr_client.subscribe_merchant(data.public_key, 0) + await resubscribe_to_all_merchants() + + await nostr_client.merchant_temp_subscription(data.public_key) return merchant except AssertionError as ex: @@ -170,8 +173,7 @@ async def api_delete_merchant( assert merchant, "Merchant cannot be found" assert merchant.id == merchant_id, "Wrong merchant ID" - # first unsubscribe so new events are not created during the clean-up - await nostr_client.unsubscribe_merchant(merchant.public_key) + await nostr_client.unsubscribe_merchants() await delete_merchant_orders(merchant.id) await delete_merchant_products(merchant.id) @@ -180,6 +182,7 @@ async def api_delete_merchant( await delete_merchant_zones(merchant.id) await delete_merchant(merchant.id) + except AssertionError as ex: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, @@ -191,7 +194,8 @@ async def api_delete_merchant( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Cannot get merchant", ) - + finally: + await subscribe_to_all_merchants() @nostrmarket_ext.put("/api/v1/merchant/{merchant_id}/nostr") async def api_republish_merchant( @@ -1057,7 +1061,8 @@ async def api_create_customer( customer = await create_customer( merchant.id, Customer(merchant_id=merchant.id, public_key=pubkey) ) - await nostr_client.subscribe_to_user_profile(pubkey, 0) + + await nostr_client.user_profile_temp_subscribe(pubkey) return customer except (ValueError, AssertionError) as ex: @@ -1084,9 +1089,7 @@ async def api_list_currencies_available(): @nostrmarket_ext.put("/api/v1/restart") async def restart_nostr_client(wallet: WalletTypeInfo = Depends(require_admin_key)): try: - ids = await get_merchants_ids_with_pubkeys() - merchant_public_keys = [id[0] for id in ids] - await nostr_client.restart(merchant_public_keys) + await nostr_client.restart() except Exception as ex: logger.warning(ex) @@ -1100,9 +1103,7 @@ async def api_stop(wallet: WalletTypeInfo = Depends(check_admin)): logger.warning(ex) try: - ids = await get_merchants_ids_with_pubkeys() - merchant_public_keys = [id[0] for id in ids] - await nostr_client.stop(merchant_public_keys) + await nostr_client.stop() except Exception as ex: logger.warning(ex)