diff --git a/__init__.py b/__init__.py index 56cd641..dc64349 100644 --- a/__init__.py +++ b/__init__.py @@ -1,6 +1,10 @@ +import asyncio + from fastapi import APIRouter +from loguru import logger from .crud import db +from .tasks import wait_for_paid_invoices from .views import castle_generic_router from .views_api import castle_api_router @@ -15,4 +19,11 @@ castle_static_files = [ } ] -__all__ = ["castle_ext", "castle_static_files", "db"] + +def castle_start(): + """Initialize Castle extension background tasks""" + logger.info("Starting Castle accounting extension background tasks") + asyncio.create_task(wait_for_paid_invoices()) + + +__all__ = ["castle_ext", "castle_static_files", "db", "castle_start"] diff --git a/crud.py b/crud.py index 0ac20fb..81e601a 100644 --- a/crud.py +++ b/crud.py @@ -226,6 +226,20 @@ async def get_journal_entry(entry_id: str) -> Optional[JournalEntry]: return entry +async def get_journal_entry_by_reference(reference: str) -> Optional[JournalEntry]: + """Get a journal entry by its reference field (e.g., payment_hash)""" + entry = await db.fetchone( + "SELECT * FROM journal_entries WHERE reference = :reference", + {"reference": reference}, + JournalEntry, + ) + + if entry: + entry.lines = await get_entry_lines(entry.id) + + return entry + + async def get_entry_lines(journal_entry_id: str) -> list[EntryLine]: rows = await db.fetchall( "SELECT * FROM entry_lines WHERE journal_entry_id = :id", diff --git a/tasks.py b/tasks.py index 991eaaf..d0a31a9 100644 --- a/tasks.py +++ b/tasks.py @@ -4,10 +4,13 @@ These tasks handle automated reconciliation checks and maintenance. """ import asyncio +from asyncio import Queue from datetime import datetime from typing import Optional +from lnbits.core.models import Payment from lnbits.tasks import register_invoice_listener +from loguru import logger from .crud import check_balance_assertion, get_balance_assertions from .models import AssertionStatus @@ -106,3 +109,103 @@ def start_daily_reconciliation_task(): print("[CASTLE] Daily reconciliation task registered") # In a production system, you would register this with LNbits task scheduler # For now, it can be triggered manually via API endpoint + + +async def wait_for_paid_invoices(): + """ + Background task that listens for paid invoices and automatically + records them in the accounting system. + + This ensures payments are recorded even if the user closes their browser + before the payment is detected by client-side polling. + """ + invoice_queue = Queue() + register_invoice_listener(invoice_queue, "ext_castle") + + while True: + payment = await invoice_queue.get() + await on_invoice_paid(payment) + + +async def on_invoice_paid(payment: Payment) -> None: + """ + Handle a paid Castle invoice by automatically creating a journal entry. + + This function is called automatically when any invoice on the Castle wallet + is paid. It checks if the invoice is a Castle payment and records it in + the accounting system. + """ + # Only process Castle-specific payments + if not payment.extra or payment.extra.get("tag") != "castle": + return + + user_id = payment.extra.get("user_id") + if not user_id: + logger.warning(f"Castle invoice {payment.payment_hash} missing user_id in metadata") + return + + # Check if payment already recorded (idempotency) + from .crud import get_journal_entry_by_reference + existing = await get_journal_entry_by_reference(payment.payment_hash) + if existing: + logger.info(f"Payment {payment.payment_hash} already recorded, skipping") + return + + logger.info(f"Recording Castle payment {payment.payment_hash} for user {user_id[:8]}") + + try: + # Import here to avoid circular dependencies + from .crud import create_journal_entry, get_account_by_name, get_or_create_user_account + from .models import AccountType, CreateEntryLine, CreateJournalEntry, JournalEntryFlag + + # Convert amount from millisatoshis to satoshis + amount_sats = payment.amount // 1000 + + # Get user's receivable account (what user owes) + user_receivable = await get_or_create_user_account( + user_id, AccountType.ASSET, "Accounts Receivable" + ) + + # Get lightning account + lightning_account = await get_account_by_name("Assets:Bitcoin:Lightning") + if not lightning_account: + logger.error("Lightning account 'Assets:Bitcoin:Lightning' not found") + return + + # Create journal entry to record payment + # DR Assets:Bitcoin:Lightning, CR Assets:Receivable (User) + # This reduces what the user owes + entry_meta = { + "source": "lightning_payment", + "created_via": "auto_invoice_listener", + "payment_hash": payment.payment_hash, + "payer_user_id": user_id, + } + + entry_data = CreateJournalEntry( + description=f"Lightning payment from user {user_id[:8]}", + reference=payment.payment_hash, + flag=JournalEntryFlag.CLEARED, + meta=entry_meta, + lines=[ + CreateEntryLine( + account_id=lightning_account.id, + debit=amount_sats, + credit=0, + description="Lightning payment received", + ), + CreateEntryLine( + account_id=user_receivable.id, + debit=0, + credit=amount_sats, + description="Payment applied to balance", + ), + ], + ) + + entry = await create_journal_entry(entry_data, user_id) + logger.info(f"Successfully recorded journal entry {entry.id} for payment {payment.payment_hash}") + + except Exception as e: + logger.error(f"Error recording Castle payment {payment.payment_hash}: {e}") + raise diff --git a/views_api.py b/views_api.py index 5f6ce0d..3753f85 100644 --- a/views_api.py +++ b/views_api.py @@ -591,7 +591,7 @@ async def api_generate_payment_invoice( amount=data.amount, memo=f"Payment from user {target_user_id[:8]} to Castle", unit="sat", - extra={"user_id": target_user_id, "type": "castle_payment"}, + extra={"tag": "castle", "user_id": target_user_id}, ) payment = await create_payment_request(castle_wallet_id, invoice_data) @@ -648,6 +648,18 @@ async def api_record_payment( detail="Payment metadata missing user_id. Cannot determine which user to credit.", ) + # Check if payment already recorded (idempotency) + from .crud import get_journal_entry_by_reference + existing = await get_journal_entry_by_reference(data.payment_hash) + if existing: + # Payment already recorded, return existing entry + balance = await get_user_balance(target_user_id) + return { + "journal_entry_id": existing.id, + "new_balance": balance.balance, + "message": "Payment already recorded", + } + # Convert amount from millisatoshis to satoshis amount_sats = payment.amount // 1000