import json
from datetime import datetime
from decimal import Decimal
from typing import Any, Optional

from lnbits.db import Database
from lnbits.helpers import urlsafe_short_hash

from .models import (
    Account,
    AccountType,
    AssertionStatus,
    BalanceAssertion,
    CastleSettings,
    CreateAccount,
    CreateBalanceAssertion,
    CreateEntryLine,
    CreateJournalEntry,
    EntryLine,
    JournalEntry,
    JournalEntryFlag,
    StoredUserWalletSettings,
    UserBalance,
    UserCastleSettings,
    UserWalletSettings,
)

# Import core accounting logic
from .core.balance import BalanceCalculator, AccountType as CoreAccountType
from .core.inventory import CastleInventory, CastlePosition
from .core.validation import (
    ValidationError,
    validate_journal_entry,
    validate_balance,
    validate_receivable_entry,
    validate_expense_entry,
    validate_payment_entry,
)

db = Database("ext_castle")


# ===== ACCOUNT OPERATIONS =====


async def create_account(data: CreateAccount) -> Account:
    """Insert a new account with a freshly generated id and return it."""
    account_id = urlsafe_short_hash()
    account = Account(
        id=account_id,
        name=data.name,
        account_type=data.account_type,
        description=data.description,
        user_id=data.user_id,
        created_at=datetime.now(),
    )
    await db.insert("accounts", account)
    return account


async def get_account(account_id: str) -> Optional[Account]:
    """Return the account with the given id, or None if it does not exist."""
    return await db.fetchone(
        "SELECT * FROM accounts WHERE id = :id",
        {"id": account_id},
        Account,
    )


async def get_account_by_name(name: str) -> Optional[Account]:
    """Get account by name (hierarchical format)"""
    return await db.fetchone(
        "SELECT * FROM accounts WHERE name = :name",
        {"name": name},
        Account,
    )


async def get_all_accounts() -> list[Account]:
    """Return every account, ordered by account type and then name."""
    return await db.fetchall(
        "SELECT * FROM accounts ORDER BY account_type, name",
        model=Account,
    )


async def get_accounts_by_type(account_type: AccountType) -> list[Account]:
    """Return all accounts of one type, ordered by name."""
    return await db.fetchall(
        "SELECT * FROM accounts WHERE account_type = :type ORDER BY name",
        {"type": account_type.value},
        Account,
    )


async def get_or_create_user_account(
    user_id: str, account_type: AccountType, base_name: str
) -> Account:
    """
    Get or create a user-specific account with hierarchical naming.

    Examples:
        get_or_create_user_account("af983632", AccountType.ASSET, "Accounts Receivable")
        → "Assets:Receivable:User-af983632"

        get_or_create_user_account("af983632", AccountType.LIABILITY, "Accounts Payable")
        → "Liabilities:Payable:User-af983632"
    """
    from .account_utils import format_hierarchical_account_name

    # Generate hierarchical account name
    account_name = format_hierarchical_account_name(account_type, base_name, user_id)

    # Try to find existing account with this hierarchical name
    account = await db.fetchone(
        """
        SELECT * FROM accounts
        WHERE user_id = :user_id AND account_type = :type AND name = :name
        """,
        {"user_id": user_id, "type": account_type.value, "name": account_name},
        Account,
    )

    if not account:
        # NOTE(review): lookup and insert are separate statements, so two
        # concurrent callers could both create the account — confirm whether
        # a unique constraint on (user_id, account_type, name) exists.
        account = await create_account(
            CreateAccount(
                name=account_name,
                account_type=account_type,
                description=f"User-specific {account_type.value} account",
                user_id=user_id,
            )
        )

    return account


# ===== JOURNAL ENTRY OPERATIONS =====


def _row_to_entry_line(row: Any) -> EntryLine:
    """Build an EntryLine from an `entry_lines` row, decoding its JSON metadata."""
    metadata = json.loads(row.metadata) if row.metadata else {}
    return EntryLine(
        id=row.id,
        journal_entry_id=row.journal_entry_id,
        account_id=row.account_id,
        debit=row.debit,
        credit=row.credit,
        description=row.description,
        metadata=metadata,
    )


def _row_to_journal_entry(row: Any) -> JournalEntry:
    """Build a JournalEntry (with empty `lines`) from a `journal_entries` row."""
    flag = JournalEntryFlag(row.get("flag", "*"))
    raw_meta = row.get("meta")
    meta = json.loads(raw_meta) if raw_meta else {}
    return JournalEntry(
        id=row["id"],
        description=row["description"],
        entry_date=row["entry_date"],
        created_by=row["created_by"],
        created_at=row["created_at"],
        reference=row["reference"],
        flag=flag,
        meta=meta,
        lines=[],
    )


async def _load_journal_entries(rows: list[Any]) -> list[JournalEntry]:
    """Convert journal_entries rows to models and attach each entry's lines."""
    entries = []
    for row in rows:
        entry = _row_to_journal_entry(row)
        entry.lines = await get_entry_lines(entry.id)
        entries.append(entry)
    return entries


async def create_journal_entry(
    data: CreateJournalEntry, created_by: str
) -> JournalEntry:
    """Persist a journal entry and its lines, returning the stored entry.

    Args:
        data: Entry header fields plus the lines to record.
        created_by: Identifier stored in the entry's `created_by` column.

    Raises:
        ValueError: if the sum of line debits differs from the sum of credits.
    """
    # Double-entry invariant: reject the entry before anything is written.
    total_debits = sum(line.debit for line in data.lines)
    total_credits = sum(line.credit for line in data.lines)
    if total_debits != total_credits:
        raise ValueError(
            f"Journal entry must balance: debits={total_debits}, credits={total_credits}"
        )

    entry_id = urlsafe_short_hash()
    entry_date = data.entry_date or datetime.now()

    journal_entry = JournalEntry(
        id=entry_id,
        description=data.description,
        entry_date=entry_date,
        created_by=created_by,
        created_at=datetime.now(),
        reference=data.reference,
        lines=[],
        flag=data.flag,
        meta=data.meta,
    )

    # NOTE(review): the header and each line are written by separate
    # statements outside an explicit transaction; a failure part-way through
    # leaves a partially written entry. Consider wrapping in a DB transaction.

    # Insert journal entry without the lines field (lines are stored in entry_lines table)
    await db.execute(
        """
        INSERT INTO journal_entries (id, description, entry_date, created_by, created_at, reference, flag, meta)
        VALUES (:id, :description, :entry_date, :created_by, :created_at, :reference, :flag, :meta)
        """,
        {
            "id": journal_entry.id,
            "description": journal_entry.description,
            "entry_date": journal_entry.entry_date,
            "created_by": journal_entry.created_by,
            "created_at": journal_entry.created_at,
            "reference": journal_entry.reference,
            "flag": journal_entry.flag.value,
            "meta": json.dumps(journal_entry.meta),
        },
    )

    lines = []
    for line_data in data.lines:
        line = EntryLine(
            id=urlsafe_short_hash(),
            journal_entry_id=entry_id,
            account_id=line_data.account_id,
            debit=line_data.debit,
            credit=line_data.credit,
            description=line_data.description,
            metadata=line_data.metadata,
        )
        # Metadata is stored as a JSON string column.
        await db.execute(
            """
            INSERT INTO entry_lines (id, journal_entry_id, account_id, debit, credit, description, metadata)
            VALUES (:id, :journal_entry_id, :account_id, :debit, :credit, :description, :metadata)
            """,
            {
                "id": line.id,
                "journal_entry_id": line.journal_entry_id,
                "account_id": line.account_id,
                "debit": line.debit,
                "credit": line.credit,
                "description": line.description,
                "metadata": json.dumps(line.metadata),
            },
        )
        lines.append(line)

    journal_entry.lines = lines
    return journal_entry


async def get_journal_entry(entry_id: str) -> Optional[JournalEntry]:
    """Return a journal entry with its lines attached, or None if missing."""
    entry = await db.fetchone(
        "SELECT * FROM journal_entries WHERE id = :id",
        {"id": entry_id},
        JournalEntry,
    )

    if entry:
        entry.lines = await get_entry_lines(entry_id)

    return entry


async def get_journal_entry_by_reference(reference: str) -> Optional[JournalEntry]:
    """Get a journal entry by its reference field (e.g., payment_hash)"""
    entry = await db.fetchone(
        "SELECT * FROM journal_entries WHERE reference = :reference",
        {"reference": reference},
        JournalEntry,
    )

    if entry:
        entry.lines = await get_entry_lines(entry.id)

    return entry


async def get_entry_lines(journal_entry_id: str) -> list[EntryLine]:
    """Return all lines belonging to one journal entry."""
    rows = await db.fetchall(
        "SELECT * FROM entry_lines WHERE journal_entry_id = :id",
        {"id": journal_entry_id},
    )
    return [_row_to_entry_line(row) for row in rows]


async def get_all_journal_entries(limit: int = 100) -> list[JournalEntry]:
    """Return the most recent journal entries (newest first), with lines."""
    entries_data = await db.fetchall(
        """
        SELECT * FROM journal_entries
        ORDER BY entry_date DESC, created_at DESC
        LIMIT :limit
        """,
        {"limit": limit},
    )
    return await _load_journal_entries(entries_data)


async def get_journal_entries_by_user(
    user_id: str, limit: int = 100
) -> list[JournalEntry]:
    """Get journal entries that affect the user's accounts.

    An entry is included when at least one of its lines posts to an account
    whose `user_id` matches. Returns an empty list when the user has no
    accounts.
    """
    # A single IN-subquery replaces the previous account fetch + dynamic
    # placeholder list, and avoids SELECT DISTINCT over every column
    # (DISTINCT is not supported on Postgres `json` columns).
    entries_data = await db.fetchall(
        """
        SELECT * FROM journal_entries
        WHERE id IN (
            SELECT el.journal_entry_id
            FROM entry_lines el
            JOIN accounts a ON el.account_id = a.id
            WHERE a.user_id = :user_id
        )
        ORDER BY entry_date DESC, created_at DESC
        LIMIT :limit
        """,
        {"user_id": user_id, "limit": limit},
    )
    return await _load_journal_entries(entries_data)


# ===== BALANCE AND REPORTING =====


async def get_account_balance(account_id: str) -> int:
    """Calculate account balance (debits - credits for assets/expenses, credits - debits for liabilities/equity/revenue)
    Only includes entries that are cleared (flag='*'), excludes pending/flagged/voided entries.

    Returns 0 when the account does not exist.
    """
    result = await db.fetchone(
        """
        SELECT
            COALESCE(SUM(el.debit), 0) as total_debit,
            COALESCE(SUM(el.credit), 0) as total_credit
        FROM entry_lines el
        JOIN journal_entries je ON el.journal_entry_id = je.id
        WHERE el.account_id = :id
          AND je.flag = '*'
        """,
        {"id": account_id},
    )

    if not result:
        return 0

    account = await get_account(account_id)
    if not account:
        return 0

    # Use core BalanceCalculator for consistent sign conventions per type.
    core_account_type = CoreAccountType(account.account_type.value)
    return BalanceCalculator.calculate_account_balance(
        result["total_debit"], result["total_credit"], core_account_type
    )


async def get_user_balance(user_id: str) -> UserBalance:
    """Get user's balance with the Castle (positive = castle owes user, negative = user owes castle)"""
    user_accounts = await db.fetchall(
        "SELECT * FROM accounts WHERE user_id = :user_id",
        {"user_id": user_id},
        Account,
    )

    account_balances = {}
    account_inventories = {}

    for account in user_accounts:
        # Satoshi balance over cleared entries.
        account_balances[account.id] = await get_account_balance(account.id)

        # Build the inventory from cleared entry lines only
        # (pending/flagged/voided entries are excluded).
        entry_lines = await db.fetchall(
            """
            SELECT el.*
            FROM entry_lines el
            JOIN journal_entries je ON el.journal_entry_id = je.id
            WHERE el.account_id = :account_id
              AND je.flag = '*'
            """,
            {"account_id": account.id},
        )

        core_account_type = CoreAccountType(account.account_type.value)
        account_inventories[account.id] = (
            BalanceCalculator.build_inventory_from_entry_lines(
                [dict(line) for line in entry_lines], core_account_type
            )
        )

    accounts_list = [
        {"id": acc.id, "account_type": acc.account_type.value}
        for acc in user_accounts
    ]
    balance_result = BalanceCalculator.calculate_user_balance(
        accounts_list, account_balances, account_inventories
    )

    return UserBalance(
        user_id=user_id,
        balance=balance_result["balance"],
        accounts=user_accounts,
        fiat_balances=balance_result["fiat_balances"],
    )


async def get_all_user_balances() -> list[UserBalance]:
    """Get balances for all users (used by castle to see who they owe)

    Only users with a non-zero sats balance or non-empty fiat balances are
    returned, ordered by user id.
    """
    all_accounts = await db.fetchall(
        "SELECT * FROM accounts WHERE user_id IS NOT NULL",
        {},
        Account,
    )

    # Sorted so the result order is deterministic (a bare set is not).
    user_ids = sorted({account.user_id for account in all_accounts if account.user_id})

    user_balances = []
    for user_id in user_ids:
        balance = await get_user_balance(user_id)
        if balance.balance != 0 or balance.fiat_balances:
            user_balances.append(balance)

    return user_balances


async def get_account_transactions(
    account_id: str, limit: int = 100
) -> list[tuple[JournalEntry, EntryLine]]:
    """Get all transactions affecting a specific account.

    Returns up to `limit` (entry, line) pairs, newest journal entry first.
    """
    # Line ids are random short hashes, so ordering by them (as before) gave an
    # arbitrary page; order by the owning entry's dates instead. The JOIN also
    # drops orphaned lines before LIMIT is applied.
    rows = await db.fetchall(
        """
        SELECT el.*
        FROM entry_lines el
        JOIN journal_entries je ON el.journal_entry_id = je.id
        WHERE el.account_id = :id
        ORDER BY je.entry_date DESC, je.created_at DESC
        LIMIT :limit
        """,
        {"id": account_id, "limit": limit},
    )

    transactions = []
    for row in rows:
        line = _row_to_entry_line(row)
        entry = await get_journal_entry(line.journal_entry_id)
        if entry:
            transactions.append((entry, line))

    return transactions


# ===== SETTINGS =====


async def create_castle_settings(
    user_id: str, data: CastleSettings
) -> CastleSettings:
    """Insert castle settings keyed by `user_id`."""
    settings = UserCastleSettings(**data.dict(), id=user_id)
    await db.insert("extension_settings", settings)
    return settings


async def get_castle_settings(user_id: str) -> Optional[CastleSettings]:
    """Return the castle settings stored for `user_id`, if any."""
    return await db.fetchone(
        """
        SELECT * FROM extension_settings
        WHERE id = :user_id
        """,
        {"user_id": user_id},
        CastleSettings,
    )


async def update_castle_settings(
    user_id: str, data: CastleSettings
) -> CastleSettings:
    """Overwrite the castle settings stored for `user_id`."""
    settings = UserCastleSettings(**data.dict(), id=user_id)
    await db.update("extension_settings", settings)
    return settings


# ===== USER WALLET SETTINGS =====


async def create_user_wallet_settings(
    user_id: str, data: UserWalletSettings
) -> UserWalletSettings:
    """Insert wallet settings keyed by `user_id`."""
    settings = StoredUserWalletSettings(**data.dict(), id=user_id)
    await db.insert("user_wallet_settings", settings)
    return settings


async def get_user_wallet_settings(user_id: str) -> Optional[UserWalletSettings]:
    """Return the wallet settings stored for `user_id`, if any."""
    return await db.fetchone(
        """
        SELECT * FROM user_wallet_settings
        WHERE id = :user_id
        """,
        {"user_id": user_id},
        UserWalletSettings,
    )


async def update_user_wallet_settings(
    user_id: str, data: UserWalletSettings
) -> UserWalletSettings:
    """Overwrite the wallet settings stored for `user_id`."""
    settings = StoredUserWalletSettings(**data.dict(), id=user_id)
    await db.update("user_wallet_settings", settings)
    return settings


async def get_all_user_wallet_settings() -> list[StoredUserWalletSettings]:
    """Get all user wallet settings"""
    return await db.fetchall(
        "SELECT * FROM user_wallet_settings ORDER BY id",
        {},
        StoredUserWalletSettings,
    )


# ===== MANUAL PAYMENT REQUESTS =====


async def create_manual_payment_request(
    user_id: str, amount: int, description: str
) -> "ManualPaymentRequest":
    """Create a new manual payment request"""
    from .models import ManualPaymentRequest

    request = ManualPaymentRequest(
        id=urlsafe_short_hash(),
        user_id=user_id,
        amount=amount,
        description=description,
        status="pending",
        created_at=datetime.now(),
    )

    await db.execute(
        """
        INSERT INTO manual_payment_requests (id, user_id, amount, description, status, created_at)
        VALUES (:id, :user_id, :amount, :description, :status, :created_at)
        """,
        {
            "id": request.id,
            "user_id": request.user_id,
            "amount": request.amount,
            "description": request.description,
            "status": request.status,
            "created_at": request.created_at,
        },
    )

    return request


async def get_manual_payment_request(request_id: str) -> Optional["ManualPaymentRequest"]:
    """Get a manual payment request by ID"""
    from .models import ManualPaymentRequest

    return await db.fetchone(
        "SELECT * FROM manual_payment_requests WHERE id = :id",
        {"id": request_id},
        ManualPaymentRequest,
    )


async def get_user_manual_payment_requests(
    user_id: str, limit: int = 100
) -> list["ManualPaymentRequest"]:
    """Get all manual payment requests for a specific user"""
    from .models import ManualPaymentRequest

    return await db.fetchall(
        """
        SELECT * FROM manual_payment_requests
        WHERE user_id = :user_id
        ORDER BY created_at DESC
        LIMIT :limit
        """,
        {"user_id": user_id, "limit": limit},
        ManualPaymentRequest,
    )


async def get_all_manual_payment_requests(
    status: Optional[str] = None, limit: int = 100
) -> list["ManualPaymentRequest"]:
    """Get all manual payment requests, optionally filtered by status"""
    from .models import ManualPaymentRequest

    if status:
        return await db.fetchall(
            """
            SELECT * FROM manual_payment_requests
            WHERE status = :status
            ORDER BY created_at DESC
            LIMIT :limit
            """,
            {"status": status, "limit": limit},
            ManualPaymentRequest,
        )

    return await db.fetchall(
        """
        SELECT * FROM manual_payment_requests
        ORDER BY created_at DESC
        LIMIT :limit
        """,
        {"limit": limit},
        ManualPaymentRequest,
    )


async def approve_manual_payment_request(
    request_id: str, reviewed_by: str, journal_entry_id: str
) -> Optional["ManualPaymentRequest"]:
    """Approve a manual payment request.

    NOTE(review): the update does not check the current status, so an already
    approved or rejected request is overwritten — confirm callers guard this.
    """
    await db.execute(
        """
        UPDATE manual_payment_requests
        SET status = 'approved', reviewed_at = :reviewed_at, reviewed_by = :reviewed_by, journal_entry_id = :journal_entry_id
        WHERE id = :id
        """,
        {
            "id": request_id,
            "reviewed_at": datetime.now(),
            "reviewed_by": reviewed_by,
            "journal_entry_id": journal_entry_id,
        },
    )

    return await get_manual_payment_request(request_id)


async def reject_manual_payment_request(
    request_id: str, reviewed_by: str
) -> Optional["ManualPaymentRequest"]:
    """Reject a manual payment request.

    NOTE(review): like approval, this does not check the current status.
    """
    await db.execute(
        """
        UPDATE manual_payment_requests
        SET status = 'rejected', reviewed_at = :reviewed_at, reviewed_by = :reviewed_by
        WHERE id = :id
        """,
        {
            "id": request_id,
            "reviewed_at": datetime.now(),
            "reviewed_by": reviewed_by,
        },
    )

    return await get_manual_payment_request(request_id)


# ===== BALANCE ASSERTION OPERATIONS =====


def _parse_decimal(value: Any, default: Optional[Decimal] = None) -> Optional[Decimal]:
    """Parse a Decimal stored as TEXT; return `default` for NULL or empty values."""
    if value is None or value == "":
        return default
    return Decimal(value)


def _decimal_to_text(value: Optional[Decimal]) -> Optional[str]:
    """Serialize a Decimal for TEXT storage, keeping None as NULL (and 0 as "0")."""
    return str(value) if value is not None else None


def _row_to_balance_assertion(row: Any) -> BalanceAssertion:
    """Build a BalanceAssertion from a `balance_assertions` row."""
    return BalanceAssertion(
        id=row["id"],
        date=row["date"],
        account_id=row["account_id"],
        expected_balance_sats=row["expected_balance_sats"],
        expected_balance_fiat=_parse_decimal(row["expected_balance_fiat"]),
        fiat_currency=row["fiat_currency"],
        tolerance_sats=row["tolerance_sats"],
        tolerance_fiat=_parse_decimal(row["tolerance_fiat"], Decimal("0")),
        checked_balance_sats=row["checked_balance_sats"],
        checked_balance_fiat=_parse_decimal(row["checked_balance_fiat"]),
        difference_sats=row["difference_sats"],
        difference_fiat=_parse_decimal(row["difference_fiat"]),
        status=AssertionStatus(row["status"]),
        created_by=row["created_by"],
        created_at=row["created_at"],
        checked_at=row["checked_at"],
    )


async def create_balance_assertion(
    data: CreateBalanceAssertion, created_by: str
) -> BalanceAssertion:
    """Create a new balance assertion in PENDING status.

    The assertion date defaults to now when `data.date` is not given.
    """
    assertion = BalanceAssertion(
        id=urlsafe_short_hash(),
        date=data.date if data.date else datetime.now(),
        account_id=data.account_id,
        expected_balance_sats=data.expected_balance_sats,
        expected_balance_fiat=data.expected_balance_fiat,
        fiat_currency=data.fiat_currency,
        tolerance_sats=data.tolerance_sats,
        tolerance_fiat=data.tolerance_fiat,
        status=AssertionStatus.PENDING,
        created_by=created_by,
        created_at=datetime.now(),
    )

    # Decimal fields are stored as TEXT. Use `is not None` rather than
    # truthiness so an expected fiat balance of exactly 0 is persisted (it was
    # previously dropped to NULL, silently disabling the fiat check), and so a
    # None tolerance is stored as NULL instead of the unparseable string "None".
    await db.execute(
        """
        INSERT INTO balance_assertions (
            id, date, account_id, expected_balance_sats, expected_balance_fiat,
            fiat_currency, tolerance_sats, tolerance_fiat, status, created_by, created_at
        )
        VALUES (
            :id, :date, :account_id, :expected_balance_sats, :expected_balance_fiat,
            :fiat_currency, :tolerance_sats, :tolerance_fiat, :status, :created_by, :created_at
        )
        """,
        {
            "id": assertion.id,
            "date": assertion.date,
            "account_id": assertion.account_id,
            "expected_balance_sats": assertion.expected_balance_sats,
            "expected_balance_fiat": _decimal_to_text(assertion.expected_balance_fiat),
            "fiat_currency": assertion.fiat_currency,
            "tolerance_sats": assertion.tolerance_sats,
            "tolerance_fiat": _decimal_to_text(assertion.tolerance_fiat),
            "status": assertion.status.value,
            "created_by": assertion.created_by,
            "created_at": assertion.created_at,
        },
    )

    return assertion


async def get_balance_assertion(assertion_id: str) -> Optional[BalanceAssertion]:
    """Get a balance assertion by ID, or None if it does not exist."""
    row = await db.fetchone(
        "SELECT * FROM balance_assertions WHERE id = :id",
        {"id": assertion_id},
    )
    if not row:
        return None
    return _row_to_balance_assertion(row)


async def get_balance_assertions(
    account_id: Optional[str] = None,
    status: Optional[AssertionStatus] = None,
    limit: int = 100,
) -> list[BalanceAssertion]:
    """Get balance assertions with optional filters, newest date first."""
    # Only fixed SQL fragments are interpolated; all values are bound params.
    conditions = []
    params: dict[str, Any] = {"limit": limit}
    if account_id:
        conditions.append("account_id = :account_id")
        params["account_id"] = account_id
    if status:
        conditions.append("status = :status")
        params["status"] = status.value

    where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""

    rows = await db.fetchall(
        f"""
        SELECT * FROM balance_assertions
        {where_clause}
        ORDER BY date DESC
        LIMIT :limit
        """,
        params,
    )

    return [_row_to_balance_assertion(row) for row in rows]


async def check_balance_assertion(assertion_id: str) -> BalanceAssertion:
    """
    Check a balance assertion by comparing expected vs actual balance.
    Updates the assertion with the check results.

    Raises:
        ValueError: if the assertion or its account does not exist.
    """
    assertion = await get_balance_assertion(assertion_id)
    if not assertion:
        raise ValueError(f"Balance assertion {assertion_id} not found")

    account = await get_account(assertion.account_id)
    if not account:
        raise ValueError(f"Account {assertion.account_id} not found")

    # NOTE(review): this is the account's *current* cleared balance; entries
    # dated after `assertion.date` are included. Confirm whether an as-of-date
    # balance was intended.
    actual_balance = await get_account_balance(assertion.account_id)

    # NOTE(review): the fiat figure comes from the owning user's aggregate
    # balance across all their accounts, not from this account alone, and it
    # is unavailable for accounts without a user_id (the fiat check is then
    # skipped and treated as matching).
    actual_fiat_balance = None
    if assertion.fiat_currency and account.user_id:
        user_balance = await get_user_balance(account.user_id)
        actual_fiat_balance = user_balance.fiat_balances.get(
            assertion.fiat_currency, Decimal("0")
        )

    difference_sats = actual_balance - assertion.expected_balance_sats
    sats_match = abs(difference_sats) <= assertion.tolerance_sats

    fiat_match = True
    difference_fiat = None
    if assertion.expected_balance_fiat is not None and actual_fiat_balance is not None:
        difference_fiat = actual_fiat_balance - assertion.expected_balance_fiat
        fiat_match = abs(difference_fiat) <= assertion.tolerance_fiat

    status = AssertionStatus.PASSED if (sats_match and fiat_match) else AssertionStatus.FAILED

    await db.execute(
        """
        UPDATE balance_assertions
        SET checked_balance_sats = :checked_sats,
            checked_balance_fiat = :checked_fiat,
            difference_sats = :diff_sats,
            difference_fiat = :diff_fiat,
            status = :status,
            checked_at = :checked_at
        WHERE id = :id
        """,
        {
            "id": assertion_id,
            "checked_sats": actual_balance,
            "checked_fiat": _decimal_to_text(actual_fiat_balance),
            "diff_sats": difference_sats,
            "diff_fiat": _decimal_to_text(difference_fiat),
            "status": status.value,
            "checked_at": datetime.now(),
        },
    )

    return await get_balance_assertion(assertion_id)


async def delete_balance_assertion(assertion_id: str) -> None:
    """Delete a balance assertion"""
    await db.execute(
        "DELETE FROM balance_assertions WHERE id = :id",
        {"id": assertion_id},
    )