from datetime import datetime
from typing import Optional

from lnbits.db import Database
from lnbits.helpers import urlsafe_short_hash

from .models import (
    Account,
    AccountType,
    CreateAccount,
    CreateEntryLine,
    CreateJournalEntry,
    EntryLine,
    JournalEntry,
    UserBalance,
)

db = Database("ext_castle")


# ===== ACCOUNT OPERATIONS =====


async def create_account(data: CreateAccount) -> Account:
    """Insert a new row into ``accounts`` and return the created account.

    The id is generated with ``urlsafe_short_hash()`` and ``created_at`` is
    set to the current local time.
    """
    account = Account(
        id=urlsafe_short_hash(),
        name=data.name,
        account_type=data.account_type,
        description=data.description,
        user_id=data.user_id,
        created_at=datetime.now(),
    )
    await db.insert("accounts", account)
    return account


async def get_account(account_id: str) -> Optional[Account]:
    """Return the account with the given id, or ``None`` if there is none."""
    return await db.fetchone(
        "SELECT * FROM accounts WHERE id = :id",
        {"id": account_id},
        Account,
    )


async def get_account_by_name(name: str) -> Optional[Account]:
    """Return an account whose name matches exactly, or ``None``."""
    return await db.fetchone(
        "SELECT * FROM accounts WHERE name = :name",
        {"name": name},
        Account,
    )


async def get_all_accounts() -> list[Account]:
    """Return every account, ordered by account type and then by name."""
    return await db.fetchall(
        "SELECT * FROM accounts ORDER BY account_type, name",
        model=Account,
    )


async def get_accounts_by_type(account_type: AccountType) -> list[Account]:
    """Return all accounts of ``account_type``, ordered by name."""
    return await db.fetchall(
        "SELECT * FROM accounts WHERE account_type = :type ORDER BY name",
        {"type": account_type.value},
        Account,
    )


async def get_or_create_user_account(
    user_id: str, account_type: AccountType, base_name: str
) -> Account:
    """Get or create a user-specific account (e.g., 'Accounts Payable - User123').

    The account name is ``base_name`` followed by the first eight characters
    of ``user_id``. The lookup matches on the full ``user_id``, the account
    type and that name; if nothing matches, a new account is created.
    """
    account_name = f"{base_name} - {user_id[:8]}"

    account = await db.fetchone(
        """
        SELECT * FROM accounts
        WHERE user_id = :user_id AND account_type = :type AND name = :name
        """,
        {"user_id": user_id, "type": account_type.value, "name": account_name},
        Account,
    )

    if not account:
        account = await create_account(
            CreateAccount(
                name=account_name,
                account_type=account_type,
                description=f"User-specific {account_type.value} account",
                user_id=user_id,
            )
        )

    return account


# ===== JOURNAL ENTRY OPERATIONS =====


async def create_journal_entry(
    data: CreateJournalEntry, created_by: str
) -> JournalEntry:
    """Validate and persist a journal entry together with its lines.

    Args:
        data: Entry header fields plus the debit/credit lines to record.
        created_by: Identifier stored in the entry's ``created_by`` column.

    Returns:
        The stored journal entry with ``lines`` populated.

    Raises:
        ValueError: If ``data.lines`` is empty, or if the sum of debits does
            not equal the sum of credits.
    """
    # An entry without lines would trivially "balance" (0 == 0) and be stored
    # as an empty record, so reject it explicitly before touching the database.
    if not data.lines:
        raise ValueError("Journal entry must have at least one line")

    # Validate that debits equal credits
    total_debits = sum(line.debit for line in data.lines)
    total_credits = sum(line.credit for line in data.lines)

    if total_debits != total_credits:
        raise ValueError(
            f"Journal entry must balance: debits={total_debits}, credits={total_credits}"
        )

    entry_id = urlsafe_short_hash()
    # Capture the time once so a defaulted entry_date and created_at agree.
    now = datetime.now()

    journal_entry = JournalEntry(
        id=entry_id,
        description=data.description,
        entry_date=data.entry_date or now,
        created_by=created_by,
        created_at=now,
        reference=data.reference,
        lines=[],
    )

    # NOTE(review): the header and each line are written with separate insert
    # calls; no explicit transaction is visible here, so a failure part-way
    # may leave a partially written entry — confirm how Database handles this.
    await db.insert("journal_entries", journal_entry)

    # Create entry lines
    lines = []
    for line_data in data.lines:
        line = EntryLine(
            id=urlsafe_short_hash(),
            journal_entry_id=entry_id,
            account_id=line_data.account_id,
            debit=line_data.debit,
            credit=line_data.credit,
            description=line_data.description,
        )
        await db.insert("entry_lines", line)
        lines.append(line)

    journal_entry.lines = lines
    return journal_entry


async def get_journal_entry(entry_id: str) -> Optional[JournalEntry]:
    """Return the journal entry with its lines attached, or ``None``."""
    entry = await db.fetchone(
        "SELECT * FROM journal_entries WHERE id = :id",
        {"id": entry_id},
        JournalEntry,
    )

    if entry:
        entry.lines = await get_entry_lines(entry_id)

    return entry


async def get_entry_lines(journal_entry_id: str) -> list[EntryLine]:
    """Return all entry lines that belong to the given journal entry."""
    return await db.fetchall(
        "SELECT * FROM entry_lines WHERE journal_entry_id = :id",
        {"id": journal_entry_id},
        EntryLine,
    )


async def get_all_journal_entries(limit: int = 100) -> list[JournalEntry]:
    """Return up to ``limit`` journal entries, newest first, with their lines.

    Entries are ordered by ``entry_date`` and then ``created_at``, both
    descending. Lines are loaded with one additional query per entry.
    """
    entries = await db.fetchall(
        """
        SELECT * FROM journal_entries
        ORDER BY entry_date DESC, created_at DESC
        LIMIT :limit
        """,
        {"limit": limit},
        JournalEntry,
    )

    for entry in entries:
        entry.lines = await get_entry_lines(entry.id)

    return entries


async def get_journal_entries_by_user(
    user_id: str, limit: int = 100
) -> list[JournalEntry]:
    """Return up to ``limit`` entries created by ``user_id``, newest first.

    Matches on the ``created_by`` column; each returned entry has its lines
    attached.
    """
    entries = await db.fetchall(
        """
        SELECT * FROM journal_entries
        WHERE created_by = :user_id
        ORDER BY entry_date DESC, created_at DESC
        LIMIT :limit
        """,
        {"user_id": user_id, "limit": limit},
        JournalEntry,
    )

    for entry in entries:
        entry.lines = await get_entry_lines(entry.id)

    return entries


# ===== BALANCE AND REPORTING =====


async def get_account_balance(account_id: str) -> int:
    """Calculate an account's balance from its entry lines.

    Asset and expense accounts return debits - credits; every other account
    type (liabilities/equity/revenue) returns credits - debits. Returns 0 when
    the account does not exist.
    """
    # Look the account up first: an unknown account always yields 0, so the
    # aggregate query below can be skipped for it.
    account = await get_account(account_id)
    if not account:
        return 0

    result = await db.fetchone(
        """
        SELECT
            COALESCE(SUM(debit), 0) as total_debit,
            COALESCE(SUM(credit), 0) as total_credit
        FROM entry_lines
        WHERE account_id = :id
        """,
        {"id": account_id},
    )

    if not result:
        return 0

    total_debit = result["total_debit"]
    total_credit = result["total_credit"]

    # Normal balance for each account type:
    # Assets and Expenses: Debit balance (debit - credit)
    # Liabilities, Equity, and Revenue: Credit balance (credit - debit)
    if account.account_type in (AccountType.ASSET, AccountType.EXPENSE):
        return total_debit - total_credit
    return total_credit - total_debit


async def get_user_balance(user_id: str) -> UserBalance:
    """Get user's balance with the Castle.

    Positive means the castle owes the user, negative means the user owes the
    castle. Liability account balances are added, asset account balances are
    subtracted, and accounts of any other type are ignored in the total.
    """
    # Get all user-specific accounts
    user_accounts = await db.fetchall(
        "SELECT * FROM accounts WHERE user_id = :user_id",
        {"user_id": user_id},
        Account,
    )

    total_balance = 0

    for account in user_accounts:
        balance = await get_account_balance(account.id)

        # If it's a liability account (castle owes user), it's positive
        # If it's an asset account (user owes castle), it's negative
        if account.account_type == AccountType.LIABILITY:
            total_balance += balance
        elif account.account_type == AccountType.ASSET:
            total_balance -= balance
        # Equity contributions are tracked but don't affect what castle owes

    return UserBalance(
        user_id=user_id,
        balance=total_balance,
        accounts=user_accounts,
    )


async def get_account_transactions(
    account_id: str, limit: int = 100
) -> list[tuple[JournalEntry, EntryLine]]:
    """Get the most recent transactions affecting a specific account.

    Returns up to ``limit`` ``(journal_entry, entry_line)`` pairs, newest
    first. Lines whose journal entry cannot be loaded are skipped.
    """
    # Line ids come from urlsafe_short_hash() and carry no time ordering, so
    # sort by the parent entry's dates (as the other listings do) rather than
    # by the line id; the id is only a tie-breaker for a stable order.
    lines = await db.fetchall(
        """
        SELECT el.* FROM entry_lines el
        JOIN journal_entries je ON je.id = el.journal_entry_id
        WHERE el.account_id = :id
        ORDER BY je.entry_date DESC, je.created_at DESC, el.id DESC
        LIMIT :limit
        """,
        {"id": account_id, "limit": limit},
        EntryLine,
    )

    transactions = []
    for line in lines:
        entry = await get_journal_entry(line.journal_entry_id)
        if entry:
            transactions.append((entry, line))

    return transactions